/*
 * Copyright 2025-present Solver4J
 *
 * This work is licensed under the Creative Commons Attribution-NoDerivatives 4.0 
 * International License. To view a copy of this license, visit 
 *
 *        http://creativecommons.org/licenses/by-nd/4.0/ 
 *
 * or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
 */
package com.solver4j.util;

import org.apache.commons.lang3.ArrayUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;

import cern.colt.function.IntIntDoubleFunction;
import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.SparseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import edu.emory.mathcs.csparsej.tdouble.Dcs_common;

/**
 * Tests for the {@link ColtUtils} helpers (lower-triangular product, CSparseJ
 * {@code Dcs} conversion, symmetric permutation, determinant, random sparse
 * matrices, scalar multiplication). It also covers a few Colt sparse-matrix
 * behaviours the helpers rely on.
 */
public class ColtUtilsTest extends Solver4JBaseTest {

	private final Logger logger = LoggerFactory.getLogger(this.getClass().getName());

	/**
	 * Builds a Colt {@link SparseDoubleMatrix2D} from a dense array. Checks that
	 * {@code forEachNonZero} visits only the non-zero entries of the source array,
	 * with their original values. Also checks that no non-zero entry is lost.
	 */
	public void testDumpSparseMatrix() {
		logger.debug("testDumpSparseMatrix");
		final double[][] A = new double[][]{
				{1, 0, 0, 2},
				{0, 0, 2, 0},
				{2, 3, 0, 0},
				{0, 0, 4, 4}
		};

		SparseDoubleMatrix2D S2 = new SparseDoubleMatrix2D(A);
		// every visited entry must match the source array
		S2.forEachNonZero(new IntIntDoubleFunction() {
			public double apply(int i, int j, double sij) {
				assertEquals(A[i][j], sij);
				return sij;
			}
		});

		// without this check the visitor above would pass vacuously on an empty matrix
		int expectedNonZeros = 0;
		for (double[] row : A) {
			for (double a : row) {
				if (a != 0) {
					expectedNonZeros++;
				}
			}
		}
		assertEquals(expectedNonZeros, S2.cardinality());
	}

	/**
	 * Checks that {@code viewPart} on a Colt sparse matrix returns a view backed by
	 * the original matrix: a value written through the view is visible in the source.
	 */
	public void testDumpSparseMatrix2() {
		logger.debug("testDumpSparseMatrix2");
		final double[][] A = new double[][]{
				{1, 0, 0, 2},
				{0, 0, 2, 0},
				{2, 3, 0, 0},
				{0, 0, 4, 4}
		};

		SparseDoubleMatrix2D S2 = new SparseDoubleMatrix2D(A);
		logger.debug("S: " + ArrayUtils.toString(S2.toArray()));

		// view of the first row (height 1, width 4)
		DoubleMatrix2D R = S2.viewPart(0, 0, 1, 4);
		R.forEachNonZero(new IntIntDoubleFunction() {
			public double apply(int i, int j, double sij) {
				logger.debug("i:" + i + ", j:" + j + ": " + sij);
				return sij;
			}
		});
		R.setQuick(0, 1, 7.0);
		logger.debug("S: " + ArrayUtils.toString(S2.toArray()));
		// the change made through R must also be visible in S2
		assertEquals(7.0, S2.getQuick(0, 1));
	}

	/**
	 * Checks {@link ColtUtils#subdiagonalMultiply} on both sparse and dense inputs.
	 * The full product A*B is {{50, 60, 64}, {114, 140, 148}, {84, 100, 107}}.
	 * The expected result keeps only the entries on and below the diagonal.
	 */
	public void testSubdiagonalMultiply() {
		logger.debug("testSubdiagonalMultiply");
		double[][] A = {{1, 2, 3, 4}, {5, 6, 7, 8}, {1, 3, 5, 7}};
		double[][] B = {{1, 2, 3}, {3, 4, 2}, {5, 6, 7}, {7, 8, 9}};
		double[][] expectedResult = {
				{50, 0, 0},
				{114, 140, 0},
				{84, 100, 107}};

		// with sparsity
		DoubleMatrix2D ASparse = DoubleFactory2D.sparse.make(A);
		DoubleMatrix2D BSparse = DoubleFactory2D.sparse.make(B);
		DoubleMatrix2D ret1 = ColtUtils.subdiagonalMultiply(ASparse, BSparse);
		logger.debug("ret1: " + ArrayUtils.toString(ret1.toArray()));
		assertEntriesMatch(expectedResult, ret1);

		// with no sparsity: the dense result must be verified on its own
		DoubleMatrix2D ADense = DoubleFactory2D.dense.make(A);
		DoubleMatrix2D BDense = DoubleFactory2D.dense.make(B);
		DoubleMatrix2D ret2 = ColtUtils.subdiagonalMultiply(ADense, BDense);
		logger.debug("ret2: " + ArrayUtils.toString(ret2.toArray()));
		assertEntriesMatch(expectedResult, ret2);
	}

	/**
	 * Checks that {@link ColtUtils#matrixToDcs} produces the expected CSparseJ
	 * compressed-column representation ({@code nz == -1}) of a 4x4 sparse matrix.
	 */
	public void testMatrixToDcs() {
		logger.debug("testMatrixToDcs");
		final double[][] A = new double[][]{
				{1, 0, 0, 2},
				{0, 0, 2, 0},
				{2, 3, 0, 0},
				{0, 0, 4, 4}
		};

		// expected fields of edu.emory.mathcs.csparsej.tdouble.Dcs_common.Dcs
		// i: row indices, size nzmax
		int[] expectedRowIndices = new int[]{2, 0, 2, 3, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
		// m: number of rows
		int expectedM = 4;
		// n: number of columns
		int expectedN = 4;
		// nz: number of entries in triplet matrix, -1 for compressed-column
		int expectedNz = -1;
		// nzmax: maximum number of entries
		int expectedNzmax = 16;
		// p: column pointers (size n+1) or column indices (size nzmax)
		int[] expectedColumnPointers = new int[]{0, 2, 3, 5, 7};
		// x: numerical values, size nzmax
		double[] expectedValues = new double[]{2.0, 1.0, 3.0, 4.0, 2.0, 4.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};

		SparseDoubleMatrix2D AMatrix = new SparseDoubleMatrix2D(A);
		Dcs_common.Dcs dcs = ColtUtils.matrixToDcs(AMatrix);

		assertEquals(expectedRowIndices.length, dcs.i.length);
		for (int k = 0; k < expectedRowIndices.length; k++) {
			assertEquals(expectedRowIndices[k], dcs.i[k]);
		}
		assertEquals(expectedM, dcs.m);
		assertEquals(expectedN, dcs.n);
		assertEquals(expectedNz, dcs.nz);
		assertEquals(expectedNzmax, dcs.nzmax);
		assertEquals(expectedColumnPointers.length, dcs.p.length);
		for (int k = 0; k < expectedColumnPointers.length; k++) {
			assertEquals(expectedColumnPointers[k], dcs.p[k]);
		}
		assertEquals(expectedValues.length, dcs.x.length);
		for (int k = 0; k < expectedValues.length; k++) {
			assertEquals(expectedValues[k], dcs.x[k]);
		}
	}

	/**
	 * Checks that {@link ColtUtils#dcsToMatrix} rebuilds the 4x4 matrix encoded by a
	 * hand-made compressed-column {@code Dcs}. It is the inverse of the
	 * {@code testMatrixToDcs} data.
	 */
	public void testDcsToMatrix() {
		logger.debug("testDcsToMatrix");

		Dcs_common.Dcs dcs = new Dcs_common.Dcs();
		dcs.i = new int[]{2, 0, 2, 3, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
		dcs.m = 4;
		dcs.n = 4;
		dcs.nz = -1;
		dcs.nzmax = 16;
		dcs.p = new int[]{0, 2, 3, 5, 7};
		dcs.x = new double[]{2.0, 1.0, 3.0, 4.0, 2.0, 4.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};

		double[][] expectedA = new double[][]{
				{1, 0, 0, 2},
				{0, 0, 2, 0},
				{2, 3, 0, 0},
				{0, 0, 4, 4}
		};
		DoubleMatrix2D S2 = ColtUtils.dcsToMatrix(dcs);

		assertEquals(4, S2.rows());
		assertEquals(4, S2.columns());
		assertEntriesMatch(expectedA, S2);
	}

	/**
	 * diag(1, 2, 3) has condition number 3 in both the 2-norm and the infinity-norm.
	 * Checks that the first element of each range returned by
	 * {@code getConditionNumberRange} does not exceed that value. That element is
	 * presumably a lower bound; {@code Integer.MAX_VALUE} presumably selects the
	 * infinity-norm. Verify both against the base class.
	 */
	public void testGetConditionNumberRanges() throws Exception {
		logger.debug("testGetConditionNumberRanges");
		double[][] A = new double[][] {
				{1., 0, 0},
				{0., 2., 0},
				{0., 0., 3.}};
		double kExpected2 = 3;
		double kExpected00 = 3;
		RealMatrix AMatrix = new Array2DRowRealMatrix(A);
		double[] cn_2 = super.getConditionNumberRange(AMatrix, 2);
		double[] cn_00 = super.getConditionNumberRange(AMatrix, Integer.MAX_VALUE);
		logger.debug("cn_2 : " + ArrayUtils.toString(cn_2));
		logger.debug("cn_00: " + ArrayUtils.toString(cn_00));
		assertTrue(kExpected2 >= cn_2[0]);
		assertTrue(kExpected00 >= cn_00[0]);
	}

	/**
	 * Swapping indices 1 and 2 of a 3x3 matrix must equal P*A*P^T, where P is the
	 * identity with rows 1 and 2 exchanged.
	 */
	public void testSymmPermutation1() throws Exception {
		logger.debug("testSymmPermutation1");
		double[][] A = new double[][] {
				{0.0, 0.1, 0.2},
				{1.0, 1.1, 1.2},
				{2.0, 2.1, 2.2}};
		double[][] P = new double[][] {
				{1, 0, 0},
				{0, 0, 1},
				{0, 1, 0}};
		assertSymmPermutationMatchesProduct(A, P, 1, 2);
	}

	/**
	 * Swapping indices 0 and 1 of a 2x2 matrix must equal P*A*P^T, where P is the
	 * identity with rows 0 and 1 exchanged.
	 */
	public void testSymmPermutation2() throws Exception {
		logger.debug("testSymmPermutation2");
		double[][] A = new double[][] {
				{0, 1},
				{1, 0}};
		double[][] P = new double[][] {
				{0, 1},
				{1, 0}};
		assertSymmPermutationMatchesProduct(A, P, 0, 1);
	}

	/** The determinant of the 3x3 identity is 1. */
	public void testCalculateDeterminant() throws Exception {
		logger.debug("testCalculateDeterminant");
		double[][] A = new double[][] {
				{ 1, 0, 0 },
				{ 0, 1, 0 },
				{ 0, 0, 1 } };
		DoubleMatrix2D AMatrix = DoubleFactory2D.dense.make(A);
		double det = ColtUtils.calculateDeterminant(AMatrix);
		assertEquals(1., det, 1.e-12);
	}

	/**
	 * Cofactor expansion gives 1*(0*6-8*7) - 3*(4*6-8*5) + 4*(4*7-0*5) = 104.
	 * A small tolerance absorbs floating-point round-off in the computed value.
	 */
	public void testCalculateDeterminant2() throws Exception {
		logger.debug("testCalculateDeterminant2");
		double[][] A = new double[][] {
				{ 1, 3, 4 },
				{ 4, 0, 8 },
				{ 5, 7, 6 } };
		DoubleMatrix2D AMatrix = DoubleFactory2D.dense.make(A);
		double det = ColtUtils.calculateDeterminant(AMatrix);
		assertEquals(104., det, 1.e-10);
	}

	/**
	 * The 40x40 random sparse matrix generated with seed 12345 and sparsity index
	 * 0.9 is expected to have a determinant of 0 (within 1e-6).
	 */
	public void testCalculateDeterminantBig() throws Exception {
		logger.debug("testCalculateDeterminantBig");
		DoubleMatrix2D A = ColtUtils.randomValuesSparseMatrix(40, 40, -1, 1, 0.9, 12345L);
		logger.debug("A: " + ArrayUtils.toString(A.toArray()));
		double det = ColtUtils.calculateDeterminant(A);
		assertEquals(0., det, 1.e-6);
	}

	/**
	 * Checks that the fraction of zero entries in a generated matrix is within 10%
	 * (relative) of the requested sparsity index.
	 */
	public void testRandomValuesSparseMatrix() {
		logger.debug("testRandomValuesSparseMatrix");
		int m = 3;
		int dim = m * m;
		double sparsityIndex = 0.9;
		DoubleMatrix2D sMatrix = ColtUtils.randomValuesSparseMatrix(m, m, -10, 10, sparsityIndex, 12345L);
		logger.debug("sMatrix: " + ArrayUtils.toString(sMatrix.toArray()));
		logger.debug("cardinality: " + sMatrix.cardinality());
		// cardinality() counts non-zero cells, so this is the number of zeros
		int zeros = dim - sMatrix.cardinality();
		double actualSparsityIndex = (double) zeros / dim;
		logger.debug("actual sparsity index: " + 100 * actualSparsityIndex + " %");
		assertTrue(Math.abs(sparsityIndex - actualSparsityIndex) / sparsityIndex < 0.1);
	}

	/** Checks that scalarMult returns c*v and leaves the input vector unchanged. */
	public void testScalarMult1D() {
		logger.debug("testScalarMult1D");
		double[] u = new double[] { 1, 2, 3 };
		DoubleMatrix1D v1 = DoubleFactory1D.dense.make(u);
		double c = 2;
		DoubleMatrix1D act = ColtUtils.scalarMult(v1, c);
		for (int i = 0; i < u.length; i++) {
			assertEquals(u[i], v1.getQuick(i)); // v1 unchanged
			assertEquals(u[i] * c, act.getQuick(i));
		}
	}

	/**
	 * Asserts that each entry of {@code expected} equals the entry at the same
	 * position in {@code actual}. Only the index range of {@code expected} is
	 * inspected.
	 */
	private void assertEntriesMatch(double[][] expected, DoubleMatrix2D actual) {
		for (int i = 0; i < expected.length; i++) {
			for (int j = 0; j < expected[i].length; j++) {
				assertEquals("entry (" + i + ", " + j + ")", expected[i][j], actual.getQuick(i, j));
			}
		}
	}

	/**
	 * Asserts that {@code ColtUtils.symmPermutation(A, i, j)} equals P*A*P^T, where
	 * P is the given permutation matrix. The comparison uses the commons-math
	 * {@code getNorm()} of the difference, with a tolerance of 1e-15.
	 */
	private void assertSymmPermutationMatchesProduct(double[][] A, double[][] P, int i, int j) throws Exception {
		DoubleFactory2D F2 = DoubleFactory2D.dense;
		Algebra ALG = Algebra.DEFAULT;
		DoubleMatrix2D AMatrix = F2.make(A);
		DoubleMatrix2D PMatrix = F2.make(P);
		DoubleMatrix2D APermuted = ColtUtils.symmPermutation(AMatrix, i, j);
		logger.debug("APermuted: " + ArrayUtils.toString(APermuted.toArray()));
		DoubleMatrix2D E = ALG.mult(PMatrix, ALG.mult(AMatrix, ALG.transpose(PMatrix)));
		double norm = MatrixUtils.createRealMatrix(E.toArray())
				.subtract(MatrixUtils.createRealMatrix(APermuted.toArray()))
				.getNorm();
		logger.debug("norm: " + norm);
		assertEquals(0, norm, 1.e-15);
	}
}
