#ifndef _LinearSystem_H
#define _LinearSystem_H

/* for time measurements */
#include <time.h>
/* for vector template */
#include <vector>
#include <list>
#include <set>
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <fstream>
#include <iostream>

#include "Q.h"
#include "Z.h"

/**
 * A sparse linear system Ax = b, solved by gaussian elimination using exact
 * arithmetic over the coefficient type K.
 *
 * The matrix is stored in compressed row form: each row keeps only its
 * non-zero entries as (column, value) pairs. The linear system is given in
 * the form [A|-b], i.e. the last column of the constraint matrix is assumed
 * to be the right hand side.
 */
template <class K>
class LinearSystem 
{
    private:
        /** number of columns of [A|-b]: the columns of A plus one for the
         *  right hand side (getNumCols() reports n-1) */
        int n;
        /** number of rows in A */
        int m;

        /** One stored (non-zero) matrix entry: column index i, value e. */
        struct Exp {
            public:
                Exp() : i(0), e(0) {}
                Exp(int index, K value) : i(index), e(value) {}
                /** Orders entries by column index; used to sort a row. */
                bool operator<(const Exp& other) const { return i < other.i; }
                int i; K e;
        };
        typedef std::vector<Exp> Row;
        typedef std::vector<Row> Matrix;
        Matrix matrix;

        /** solution vector */
        std::vector<K> soln;

        /** pivot rows */
        std::vector<int> pivots;
        void triangulate();

        // Add k times row i to row j, with k chosen to cancel j's first entry.
        void reduce_row(int i, int j);
        // Add k times row i to row j, with k chosen to cancel j's last entry.
        void reduce_row_reverse(int i, int j);
        void normalise_row(int i); // Remove any factors.

        void swap_rows(int i, int j); // Swap rows i and j.

        /** if consistent, find actual solution */
        void back_solve1(); 
        void back_solve2(); 

        /** Checks whether the current solution is in fact a solution of the
         * linear system. **/
        void check_solution();

    public:

        /** preprocessing: sort the entries of every row by column */
        void preprocessMatrix();

        /** public constructor */
        LinearSystem(int num_col);

        /** public destructor */
        virtual ~LinearSystem();

        /** Get the solution vector. */
        const std::vector<K>& getSolution();

        /** add entry (i,j) to the constraint matrix. */
        void addEntry(int i, int j, K v);

        /** add entry i to the right hand side. */
        void addRHSEntry(int i, K v);

        /* get number of rows in matrix */ 
        int getNumRows() const {return m;}
        /* set the number of rows in matrix (rows beyond s are discarded) */ 
        void setNumRows(Index s) { m = s; matrix.resize(s); }
        /* get number of columns in matrix (not including rhs) */ 
        int getNumCols() const { return n-1; }
        /* set number of columns in matrix (not including rhs) */ 
        void setNumCols(int _n) { n = _n+1; }
        // Get the number of non zero entries.
        long int getNumNonZeros() const;

        /** determine if the linear system is consistent */
        bool isConsistent();

        /** determine the rank of the linear system. */
        int rank();
        int rank_reverse();
        int rank_reverse(Index cutoff);

        /** Print the linear system. */
        void print() const;
        void print_row(Index i) const;
    
        /** Print the solution of the linear system. */
        void printSolution() const;

        int get_codim(Index start, Index end) const;

        /** Read-only iterator over the stored entries of one row. */
        class RowIter
        {
        public:
            RowIter(const RowIter& ri) : it(ri.it) {}
            RowIter& operator=(const RowIter& ri) { it = ri.it; return *this; }
            void operator++() { ++it; }
            bool operator!=(const RowIter& other) const { return it != other.it; }
            bool operator==(const RowIter& other) const { return it == other.it; }
            // Returned by value: the stored column is an int, so a reference
            // to Index would bind to a temporary if Index were another type.
            Index index() const { return it->i; }
            const K& coeff() const { return it->e; }
        protected:
            typename Row::const_iterator it;
            RowIter(typename Row::const_iterator _it) : it(_it) {}
            friend class LinearSystem<K>;
        };

        RowIter begin(Index i) const { return RowIter(matrix[i].begin()); }
        RowIter end(Index i) const { return RowIter(matrix[i].end()); }
        /** Iterator to the last entry of row i; row i must not be empty. */
        RowIter last(Index i) const { return RowIter(matrix[i].end() - 1); }
};

template <class K>
inline void
LinearSystem<K>::swap_rows(int i, int j)
{
    // Exchange the two rows' storage in O(1); no entries are copied.
    Row& first = matrix[i];
    first.swap(matrix[j]);
}

/**
 * Create an empty system (no rows). num_col is stored as-is in n, which the
 * rest of the class treats as including the right hand side column
 * (addRHSEntry() writes to column n-1).
 */
template <class K>
LinearSystem<K>::LinearSystem(int num_col)
    : n(num_col), m(0)
{
    // A negative column count is a programming error.
    assert(num_col >= 0);
}

template <class K>
LinearSystem<K>::~LinearSystem()
{
    // Nothing to release: every member is a standard container that frees
    // its own storage.
}

/**
 * Add entry (i,j) with value v to the constraint matrix.
 *
 * Zero values are ignored, since only non-zeros are stored. The matrix grows
 * to i+1 rows when needed. Entries are appended unsorted, so
 * preprocessMatrix() must run before elimination (isConsistent(), rank() and
 * rank_reverse() call it). This function assumes that the same entry is not
 * added twice.
 *
 * @param i row index; must be non-negative
 * @param j column index in [0, n); column n-1 is the right hand side
 * @param v coefficient value
 */
template <class K>
void
LinearSystem<K>::addEntry(int i, int j, K v)
{
    if (v == 0) { return; }
    // A negative row would skip the resize below and index out of bounds.
    assert(0 <= i);
    assert(0 <= j && j < n);
    if (i >= m) {
        matrix.resize(i+1);
        m = i+1;
    }
    matrix[i].push_back(Exp(j,v));
}

/**
 * Store value v in the right hand side of row i.
 * The right hand side is the last column (index n-1) of the constraint
 * matrix [A|-b]. Zero values are ignored. This function assumes that the
 * same entry is not added twice.
 */
template <class K>
void
LinearSystem<K>::addRHSEntry(int i, K v)
{
    const int rhs_col = n - 1;
    if (v != 0) { addEntry(i, rhs_col, v); }
}


/** Return the total number of stored (non-zero) entries over all rows. */
template <class K>
long int
LinearSystem<K>::getNumNonZeros() const
{
    long int total = 0;
    for (size_t r = 0; r < matrix.size(); ++r) {
        total += matrix[r].size();
    }
    return total;
}

/**
 * Sort the entries of every row by increasing column index.
 * The elimination routines merge rows and read front()/back() as the first
 * and last non-zero column, which relies on this ordering.
 */
template <class K>
void
LinearSystem<K>::preprocessMatrix()
{
    for (int r = 0; r < m; ++r) {
        // Qualified call: a vector's iterator type is implementation-defined
        // and need not bring namespace std into argument-dependent lookup.
        std::sort(matrix[r].begin(), matrix[r].end());
    }
}

/**
 * Eliminate the first entry of row j using row i:
 *   row j += factor * row i,  factor = -(first entry of j)/(first entry of i).
 *
 * Callers pass two distinct, non-empty rows sorted by column whose first
 * entries lie in the same column, so that entry cancels. The merge keeps the
 * result sorted, drops entries that cancel to zero, and then normalises
 * row j.
 */
template <class K>
void
LinearSystem<K>::reduce_row(int i, int j)
{
    assert(0 <= i && i < m);
    assert(0 <= j && j < m);
    assert(i != j); // ri and rj must not alias: rj is overwritten below.
    Row& ri = matrix[i];
    Row& rj = matrix[j];
    assert(!ri.empty() && !rj.empty()); // front() is read below.
    // Scratch buffer reused across calls to avoid reallocation.
    // TODO: Not Thread Safe.
    static Row r;
    r.reserve(ri.size() + rj.size());
    r.resize(0);

    // Named 'factor' rather than 'm' so it does not shadow the row count.
    K factor(-rj.front().e/ri.front().e);

    typename Row::iterator iti = ri.begin();
    typename Row::iterator itj = rj.begin();
    while (iti != ri.end() && itj != rj.end()) {
        if (iti->i < itj->i) {
            r.push_back(Exp(iti->i, factor*(iti->e)));
            ++iti;
        }
        else if (itj->i < iti->i) {
            r.push_back(*itj);
            ++itj;
        }
        else {
            itj->e += factor*(iti->e);
            if (itj->e != 0) { r.push_back(*itj); }
            ++iti; ++itj;
        }
    }

    // At most one of the two tails is non-empty; append it.
    r.insert(r.end(), itj, rj.end());
    while (iti != ri.end()) {
        r.push_back(Exp(iti->i, factor*(iti->e)));
        ++iti;
    }
    // O(1) exchange instead of copying the merged row into rj. r takes over
    // rj's old storage and is cleared at the start of the next call.
    rj.swap(r);
    normalise_row(j);
}

/**
 * Eliminate the last entry of row j using row i:
 *   row j += factor * row i,  factor = -(last entry of j)/(last entry of i).
 *
 * Callers pass two distinct, non-empty rows sorted by column whose last
 * entries lie in the same column, so that entry cancels. The merge keeps the
 * result sorted, drops entries that cancel to zero, and then normalises
 * row j.
 */
template <class K>
void
LinearSystem<K>::reduce_row_reverse(int i, int j)
{
    assert(0 <= i && i < m);
    assert(0 <= j && j < m);
    assert(i != j); // ri and rj must not alias: rj is overwritten below.
    Row& ri = matrix[i];
    Row& rj = matrix[j];
    assert(!ri.empty() && !rj.empty()); // back() is read below.
    // Scratch buffer reused across calls to avoid reallocation.
    // TODO: Not Thread Safe.
    static Row r;
    r.reserve(ri.size() + rj.size());
    r.resize(0);

    // Named 'factor' rather than 'm' so it does not shadow the row count.
    K factor(-rj.back().e/ri.back().e);

    typename Row::iterator iti = ri.begin();
    typename Row::iterator itj = rj.begin();
    while (iti != ri.end() && itj != rj.end()) {
        if (iti->i < itj->i) {
            r.push_back(Exp(iti->i, factor*(iti->e)));
            ++iti;
        }
        else if (itj->i < iti->i) {
            r.push_back(*itj);
            ++itj;
        }
        else {
            itj->e += factor*(iti->e);
            if (itj->e != 0) { r.push_back(*itj); }
            ++iti; ++itj;
        }
    }

    // At most one of the two tails is non-empty; append it.
    r.insert(r.end(), itj, rj.end());
    while (iti != ri.end()) {
        r.push_back(Exp(iti->i, factor*(iti->e)));
        ++iti;
    }
    // O(1) exchange instead of copying the merged row into rj. r takes over
    // rj's old storage and is cleared at the start of the next call.
    rj.swap(r);
    normalise_row(j);
}

/**
 * Generic row normalisation: does nothing for a general coefficient type.
 * Specialisations for particular K (e.g. Q) may remove common factors.
 */
template <class K>
inline void
LinearSystem<K>::normalise_row(int /* i */)
{
    // Intentionally empty.
}

/**
 * Remove common factors from row i of a rational system.
 *
 * Every numerator is divided by the gcd of the row's numerators, and every
 * denominator by the gcd of the row's denominators. Both divisions apply one
 * common factor to the whole row, so the equation it represents is only
 * rescaled.
 */
template <>
inline void
LinearSystem<Q>::normalise_row(int i)
{
    Row& row = matrix[i];
    if (row.empty()) { return; }

    // Seed both gcds from the first entry and fold in the remaining entries.
    // Once a gcd reaches 1 it can no longer change, so it stops being updated.
    Z num_gcd = row[0].e.get_num();
    Z den_gcd = row[0].e.get_den();
    for (size_t k = 1; k < row.size(); ++k) {
        const bool num_done = (num_gcd == 1);
        const bool den_done = (den_gcd == 1);
        if (num_done && den_done) { break; }
        if (!num_done) {
            mpz_gcd(num_gcd.get_mpz_t(), row[k].e.get_num_mpz_t(), num_gcd.get_mpz_t());
        }
        if (!den_done) {
            mpz_gcd(den_gcd.get_mpz_t(), row[k].e.get_den_mpz_t(), den_gcd.get_mpz_t());
        }
    }

    const bool divide_num = (num_gcd != 1);
    const bool divide_den = (den_gcd != 1);
    for (size_t k = 0; k < row.size(); ++k) {
        if (divide_num) { row[k].e.get_num() /= num_gcd; }
        if (divide_den) { row[k].e.get_den() /= den_gcd; }
    }
}

/**
 * Determine whether the system [A|-b] is consistent, by forward gaussian
 * elimination on the first non-zero entry of each row.
 *
 * Side effects: the rows are sorted and reduced in place, and the pivot row
 * chosen for each column is recorded in 'pivots' (used by back_solve2()).
 * Returns false as soon as some row's first non-zero entry lies in the RHS
 * column n-1, i.e. the row reads 0 = non-zero.
 */
template <class K>
bool
LinearSystem<K>::isConsistent()
{
    preprocessMatrix();
    // Start from an empty pivot list. Otherwise a repeated call (or an
    // earlier rank()) would leave duplicate pivots, and back_solve2() would
    // substitute the same rows twice.
    pivots.clear();
 
    // col_entries[c] lists the rows whose first non-zero entry is in column c.
    std::vector<std::vector<int> > col_entries(n);
    for (int r = 0; r < m; ++r) {
        if (!matrix[r].empty()) {
            if (matrix[r].front().i == n-1) { return false; }
            col_entries[matrix[r].front().i].push_back(r);
        }
    }

    for (int c = 0; c < n; ++c) {
        // Choose the sparsest row starting in column c as the pivot.
        std::vector<int>& col = col_entries[c];
        if (col.empty()) { continue; }
        int next = col.front();
        for (int i = 1; i < (int) col.size(); ++i) {
                if (matrix[col[i]].size() < matrix[next].size()) { next = col[i]; }
        }

        pivots.push_back(next);

        // Eliminate column c from the other rows. A reduced row's first entry
        // moves to a later column, so it is queued there and 'col' is
        // not modified.
        for (int i = 0; i < (int) col.size(); ++i) {
            int r = col[i];
            if (r != next) {
                reduce_row(next, r);
                if (!matrix[r].empty()) {
                    // Check if system is inconsistent.
                    if (matrix[r].front().i == n-1) { return false; }
                    // Update first non-zero entries data structure.
                    col_entries[matrix[r].front().i].push_back(r);
                }
            }
        }
    }
    return true;
}

/**
 * Compute the rank of the constraint matrix [A|-b] (the RHS column is
 * included) by forward gaussian elimination on the first non-zero entry of
 * each row.
 *
 * The rows are sorted and reduced in place, and the pivot row chosen for each
 * column is recorded in 'pivots'. The returned rank is the number of pivots.
 *
 * NOTE(review): a row whose only initial entry is a non-zero RHS gets a pivot
 * in column n-1 and is counted. An earlier comment called this incorrect;
 * confirm whether the rank of A or of [A|-b] is intended.
 */
template <class K>
int
LinearSystem<K>::rank()
{
    preprocessMatrix();
    // Start from an empty pivot list so repeated calls return the same rank
    // instead of accumulating pivots.
    pivots.clear();

    // col_entries[c] lists the rows whose first non-zero entry is in column c.
    std::vector<std::vector<int> > col_entries(n);
    for (int r = 0; r < m; ++r) {
        if (!matrix[r].empty()) {
            col_entries[matrix[r].front().i].push_back(r);
        }
    }

    for (int c = 0; c < n; ++c) {
        // Choose the sparsest row starting in column c as the pivot.
        std::vector<int>& col = col_entries[c];
        if (col.empty()) { continue; }
        int next = col.front();
        for (int i = 1; i < (int) col.size(); ++i) {
            if (matrix[col[i]].size() < matrix[next].size()) { next = col[i]; }
        }

        pivots.push_back(next);

        // Eliminate column c from the other rows; each reduced row is
        // queued at its new (later) first column.
        for (int i = 0; i < (int) col.size(); ++i) {
            int r = col[i];
            if (r != next) {
                reduce_row(next, r);
                if (!matrix[r].empty()) {
                    // Update first non-zero entries data structure.
                    col_entries[matrix[r].front().i].push_back(r);
                }
            }
        }
    }
    return static_cast<int>(pivots.size());
}

/**
 * Compute the rank of [A|-b] by reverse elimination, i.e. rank_reverse(Index)
 * with a cutoff of 0. With that cutoff, the sparsest candidate row is always
 * chosen as the pivot.
 */
template <class K>
int
LinearSystem<K>::rank_reverse()
{
    const Index no_cutoff = 0;
    return rank_reverse(no_cutoff);
}

/**
 * Compute the rank of [A|-b] (the RHS column is included) by gaussian
 * elimination on the LAST non-zero entry of each row, sweeping the columns
 * from right (n-1) to left (0).
 *
 * For each column c, one row among those whose last entry is in column c is
 * chosen as the pivot. That entry is eliminated from the other candidates
 * with reduce_row_reverse().
 *
 * Pivot choice: if the first candidate (col.front()) has a row index below
 * 'cutoff', it is kept as the pivot. Otherwise the candidate with the fewest
 * stored entries is used, the earliest one on ties. With cutoff = 0 the
 * sparsest candidate is therefore always chosen.
 *
 * Afterwards all empty rows are removed from the matrix, m is set to the
 * number of remaining rows, and that number is returned. Pivot rows recorded
 * by an earlier isConsistent()/rank() call are discarded.
 *
 * NOTE(review): an earlier comment warned that this misbehaves when a row
 * initially has only a non-zero RHS entry. Confirm the intended semantics.
 */
template <class K>
int
LinearSystem<K>::rank_reverse(Index cutoff)
{
	preprocessMatrix();
    //print();
    // col_entries[c] lists the rows whose last non-zero entry is in column c.
    std::vector<std::vector<int> > col_entries(n);
    for (int r = 0; r < m; ++r) {
        if (!matrix[r].empty()) {
            col_entries[matrix[r].back().i].push_back(r);
        }
    }

    for (int c = n-1; c >= 0; --c) {
        // Choose the next row (the pivot for column c).
        std::vector<int>& col = col_entries[c];
        if (col.empty()) { continue; }
        int next = col.front();
        // A first candidate below 'cutoff' keeps priority as the pivot;
        // otherwise prefer the sparsest candidate.
        if (next >= cutoff) {
            for (int i = 1; i < (int) col.size(); ++i) {
                if (matrix[col[i]].size() < matrix[next].size()) { next = col[i]; }
            }
        }

        // Eliminate column c from every other candidate. A reduced row's last
        // entry moves to an earlier column, so it is re-queued there.
        for (int i = 0; i < (int) col.size(); ++i) {
            int r = col[i];
            if (r != next) {
                reduce_row_reverse(next, r);
                if (!matrix[r].empty()) {
                    // Update last non-zero entries data structure.
                    col_entries[matrix[r].back().i].push_back(r);
                }
            }
        }
        //print();
    }
    pivots.clear();
    // We now remove empty rows from the matrix, compacting the non-empty
    // rows to the front while keeping their relative order.
    Index j = 0;
    for (size_t i = 0; i < matrix.size(); ++i) {
        if (!matrix[i].empty()) { matrix[i].swap(matrix[j]); ++j; }
    }
    matrix.erase(matrix.begin()+j, matrix.end());
    m = j;

    return m;
}


/**
 * Return (end - start) minus the number of rows whose last non-zero column
 * lies in [start, end).
 *
 * Intended to be called right after rank_reverse(). After that call each
 * remaining row ends in a distinct column, so the count is the number of
 * pivot columns in the range. Empty rows, which rank_reverse() removes, are
 * skipped here rather than dereferenced.
 */
template <class K>
int
LinearSystem<K>::get_codim(Index start, Index end) const
{
    int dim = 0;
    for (size_t i = 0; i < matrix.size(); ++i) {
        const Row& row = matrix[i];
        // back() on an empty row is undefined behaviour.
        if (row.empty()) { continue; }
        const int last = row.back().i;
        if (last >= start && last < end) { ++dim; }
    }
    return (end-start) - dim;
}


/**
 * Compute a particular solution by back substitution and verify it against
 * every row. check_solution() terminates the program if verification fails.
 *
 * Relies on the pivot rows recorded by an earlier isConsistent() or rank()
 * call; with no recorded pivots every variable is 0. The returned vector has
 * n-1 entries, one per column of A.
 */
template <class K>
const std::vector<K>&
LinearSystem<K>::getSolution()
{
    this->back_solve2();
    this->check_solution();
    return this->soln;
}

/**
 * Compute a particular solution by back substitution over the pivot rows
 * recorded by isConsistent() (or rank()).
 *
 * The RHS column n-1 is fixed to 1 and every non-pivot column (free variable)
 * to 0. Pivots were recorded in increasing column order, and each pivot row's
 * other entries lie in later columns. Walking the pivots in reverse therefore
 * meets every needed value already computed. The result (n-1 entries) is
 * stored in soln. Assumes the system is consistent.
 */
template <class K>
void
LinearSystem<K>::back_solve2()
{
    // Column n-1 (the RHS) is indexed below, so it must exist.
    assert(n >= 1);
    soln.clear();
    soln.resize(n, 0);
    soln[n-1] = 1;
    for (std::vector<int>::reverse_iterator i = pivots.rbegin(); i != pivots.rend(); ++i) {
        const Row& r = matrix[*i];
        assert(!r.empty());
        const int c = r.front().i;
        for (int j = 1; j < (int) r.size(); ++j) {
            soln[c] -= soln[r[j].i]*r[j].e;
        }
        assert(r[0].e != 0);
        soln[c] /= r[0].e;
    }
    soln.resize(n-1);
}

/**
 * Verify that soln satisfies every row of [A|-b], with the RHS variable set
 * to 1. On failure an error message is written to std::cerr and the program
 * exits with status 1.
 */
template <class K>
void
LinearSystem<K>::check_solution()
{
    // Pad soln to n entries with the value 1. After back_solve2() this only
    // appends the RHS variable, so column n-1 can be indexed directly.
    soln.resize(n,1);
    for (int row = 0; row < m; ++row) {
        const Row& entries = matrix[row];
        K residual = 0;
        for (size_t k = 0; k < entries.size(); ++k) {
            residual += soln[entries[k].i]*entries[k].e;
        }
        if (residual != 0) {
           std::cerr << "ERROR: Check solution failed!\n";
           exit(1);
        }
    }
    soln.resize(n-1);
}

/** Print the solution vector on one line, each value preceded by a space. */
template <class K>
void
LinearSystem<K>::printSolution() const
{
    typename std::vector<K>::const_iterator value;
    for (value = soln.begin(); value != soln.end(); ++value) {
        std::cout << " " << *value;
    }
    std::cout << "\n";
}

/** Print every row densely (one row per line), then an empty line and flush. */
template <class K>
void
LinearSystem<K>::print() const
{
    int row = 0;
    while (row < m) {
        print_row(row);
        ++row;
    }
    std::cout << std::endl;
}

/**
 * Print row i densely: one value per column 0..n-1, writing " 0" for columns
 * with no stored entry. Assumes the row is sorted by column
 * (see preprocessMatrix()).
 */
template <class K>
void
LinearSystem<K>::print_row(Index i) const
{
    const Row& row = matrix[i];
    typename Row::const_iterator entry = row.begin();
    for (int col = 0; col < n; ++col) {
        const bool stored = (entry != row.end()) && !(col < entry->i);
        if (stored) {
            std::cout << " " << entry->e;
            ++entry;
        } else {
            std::cout << " 0";
        }
    }
    std::cout << "\n";
}

#endif

