/*
 * Copyright 2025-present Solver4J
 *
 * This work is licensed under the Creative Commons Attribution-NoDerivatives 4.0 
 * International License. To view a copy of this license, visit 
 *
 *        http://creativecommons.org/licenses/by-nd/4.0/ 
 *
 * or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
 */
package com.solver4j.linear.factorization;

import java.io.File;

import org.apache.commons.lang3.ArrayUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.solver4j.util.ColtUtils;
import com.solver4j.util.Solver4JBaseTest;

import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.colt.matrix.linalg.SingularValueDecomposition;

public class MatrixLogSumRescalerTest extends Solver4JBaseTest {

	private Logger logger = LoggerFactory.getLogger(this.getClass().getName());
	
	public void testSimpleScalingNoSymm() throws Exception {
		logger.debug("testSimpleScalingNoSymm");
		DoubleFactory2D F2 = DoubleFactory2D.dense;
		DoubleFactory1D F1 = DoubleFactory1D.dense;
		Algebra ALG = Algebra.DEFAULT;
		final double[][] A = new double[][]{
				{1, 0, 0},
				{0, 0, 2},
				{2, 3, 0},
				{0, 0, 4}
		};
		DoubleMatrix2D AMatrix = F2.make(A);
		MatrixRescaler rescaler = new MatrixLogSumRescaler();
		DoubleMatrix1D[] UV = rescaler.getMatrixScalingFactors(AMatrix);
		DoubleMatrix2D AScaled = ColtUtils.diagonalMatrixMult(UV[0], AMatrix, UV[1]);
		logger.debug("AScaled : " + ArrayUtils.toString(AScaled.toArray()));
		double cn0 = new SingularValueDecomposition(AMatrix).cond();
		double cn1 = new SingularValueDecomposition(AScaled).cond();
		double norm0 = ALG.normInfinity(AMatrix);
		double norm1 = ALG.normInfinity(AScaled);
		logger.debug("U : " + ArrayUtils.toString(UV[0].toArray()));
		logger.debug("V : " + ArrayUtils.toString(UV[1].toArray()));
		logger.debug("cn0: " + cn0);
		logger.debug("cn1: " + cn1);
		logger.debug("norm0: " + norm0);
		logger.debug("norm1: " + norm1);
		assertTrue(rescaler.checkScaling(AMatrix, UV[0], UV[1]));
		assertFalse(cn1 > cn0);//not guaranteed by the rescaling
		assertFalse(norm1 > norm0);//not guaranteed by the rescaling
	}
	
	public void testSimpleScalingSymm() throws Exception {
		logger.debug("testSimpleScalingSymm");
		DoubleFactory2D F2 = DoubleFactory2D.dense;
		Algebra ALG = Algebra.DEFAULT;
		double[][] A = new double[][] {
				{1., 0.5e7, 0}, 
				{0.5e7, 2., 0}, 
				{0., 0., 3.e-9}};
		DoubleMatrix2D AMatrix = F2.make(A);
		MatrixRescaler rescaler = new MatrixLogSumRescaler();
		DoubleMatrix1D U = rescaler.getMatrixScalingFactorsSymm(AMatrix);
		DoubleMatrix2D AScaled = ColtUtils.diagonalMatrixMult(U, AMatrix, U);
		double cn0 = new SingularValueDecomposition(AMatrix).cond();
		double cn1 = new SingularValueDecomposition(AScaled).cond();
		double norm0 = ALG.normInfinity(AMatrix);
		double norm1 = ALG.normInfinity(AScaled);
		logger.debug("U : " + ArrayUtils.toString(U.toArray()));
		logger.debug("AScaled : " + ArrayUtils.toString(AScaled.toArray()));
		logger.debug("cn0: " + cn0);
		logger.debug("cn1: " + cn1);
		logger.debug("norm0: " + norm0);
		logger.debug("norm1: " + norm1);
		assertTrue(rescaler.checkScaling(AMatrix, U, U));
		assertFalse(cn1 > cn0);
		assertFalse(norm1 > norm0);
	}
	
	/**
	 * Test of the matrix in Gajulapalli example 2.1.
	 * It is a Pathological Square Matrix.
	 * @see Gajulapalli, Lasdon "Scaling Sparse Matrices for Optimization Algorithms"
	 */
	public void testPathologicalScalingNoSymm() throws Exception {
		logger.debug("testPathologicalScalingNoSymm");
		DoubleFactory2D F2 = DoubleFactory2D.dense;
		DoubleFactory1D F1 = DoubleFactory1D.dense;
		Algebra ALG = Algebra.DEFAULT;
		double[][] A = new double[][] { 
				{ 1.e0, 1.e10, 1.e20 }, 
				{ 1.e10, 1.e30, 1.e50 }, 
				{ 1.e20, 1.e40, 1.e80 } };
		DoubleMatrix2D AMatrix = F2.make(A);
		MatrixRescaler rescaler = new MatrixLogSumRescaler();
		DoubleMatrix1D[] UV = rescaler.getMatrixScalingFactors(AMatrix);
		DoubleMatrix2D AScaled = ColtUtils.diagonalMatrixMult(UV[0], AMatrix, UV[1]);
		double cn0 = new SingularValueDecomposition(AMatrix).cond();
		double cn1 = new SingularValueDecomposition(AScaled).cond();
		double norm0 = ALG.normInfinity(AMatrix);
		double norm1 = ALG.normInfinity(AScaled);
		logger.debug("U : " + ArrayUtils.toString(UV[0].toArray()));
		logger.debug("V : " + ArrayUtils.toString(UV[1].toArray()));
		logger.debug("AScaled : " + ArrayUtils.toString(AScaled.toArray()));
		logger.debug("cn0: " + cn0);
		logger.debug("cn1: " + cn1);
		logger.debug("norm0: " + norm0);
		logger.debug("norm1: " + norm1);
		assertTrue(rescaler.checkScaling(AMatrix, UV[0], UV[1]));
		assertFalse(cn1 > cn0);//not guaranteed by the rescaling
		assertFalse(norm1 > norm0);//not guaranteed by the rescaling
	}
	
	/**
	 * Test of the matrix in Gajulapalli example 3.1.
	 * It is a Pathological Square Matrix.
	 * @see Gajulapalli, Lasdon "Scaling Sparse Matrices for Optimization Algorithms"
	 */
	public void testPathologicalScalingSymm() throws Exception {
		logger.debug("testPathologicalScalingSymm");
		DoubleFactory2D F2 = DoubleFactory2D.dense;
		DoubleFactory1D F1 = DoubleFactory1D.dense;
		Algebra ALG = Algebra.DEFAULT;
		double[][] A = new double[][] { 
				{ 1.e0,  1.e20, 1.e10, 1.e0  }, 
				{ 1.e20, 1.e20, 1.e0,  1.e40 }, 
				{ 1.e10, 1.e0,  1.e40, 1.e50 },
				{ 1.e0 , 1.e40, 1.e50, 1.e0 }};
		DoubleMatrix2D AMatrix = F2.make(A);
		MatrixRescaler rescaler = new MatrixLogSumRescaler();
		DoubleMatrix1D[] UV = rescaler.getMatrixScalingFactors(AMatrix);
		DoubleMatrix2D AScaled = ColtUtils.diagonalMatrixMult(UV[0], AMatrix, UV[1]);
		double cn0 = new SingularValueDecomposition(AMatrix).cond();
		double cn1 = new SingularValueDecomposition(AScaled).cond();
		double norm0 = ALG.normInfinity(AMatrix);
		double norm1 = ALG.normInfinity(AScaled);
		logger.debug("U : " + ArrayUtils.toString(UV[0].toArray()));
		logger.debug("V : " + ArrayUtils.toString(UV[1].toArray()));
		logger.debug("AScaled : " + ArrayUtils.toString(AScaled.toArray()));
		logger.debug("cn0: " + cn0);
		logger.debug("cn1: " + cn1);
		logger.debug("norm0: " + norm0);
		logger.debug("norm1: " + norm1);
		assertTrue(rescaler.checkScaling(AMatrix, UV[0], UV[1]));
		assertFalse(cn1 > cn0);//not guaranteed by the rescaling
		assertFalse(norm1 > norm0);//not guaranteed by the rescaling
	}
	
	/**
	 * Test the matrix norm before and after scaling.
	 * Note that scaling is not guaranteed to give a better condition number.
	 * The test shows some issue with matrix norm, in that this type of scaling
	 * in not effective in the norm with this matrix.
	 */
	public void testMatrixNormScaling7() throws Exception {
		logger.debug("testMatrixNormScaling7");
		DoubleFactory1D F1 = DoubleFactory1D.dense;
		DoubleFactory2D F2 = DoubleFactory2D.dense;
		Algebra ALG = Algebra.DEFAULT;
		
		String matrixId = "7";
		double[][] A = super.loadDoubleMatrixFromFile("factorization" + File.separator + "matrix" + matrixId + ".csv", ",".charAt(0));
		final DoubleMatrix2D AMatrix = F2.make(A);
		
		MatrixRescaler rescaler = new MatrixLogSumRescaler();
		final DoubleMatrix1D U = rescaler.getMatrixScalingFactorsSymm(AMatrix);
		final DoubleMatrix2D AScaled = ColtUtils.diagonalMatrixMult(U, AMatrix, U);
		//logger.debug("AScaled : " + ArrayUtils.toString(AScaled.toArray()));
		
		double norm0 = ALG.normInfinity(AMatrix);
		double norm1 = ALG.normInfinity(AScaled);
		logger.debug("U : " + ArrayUtils.toString(U.toArray()));
		logger.debug("norm0: " + norm0);
		logger.debug("norm1: " + norm1);
		
		assertTrue(rescaler.checkScaling(AMatrix, U, U));//note: this must be guaranteed
		logger.debug("better matrix norm: " + (norm1 > norm0));
		//assertFalse(norm1 > norm0);//note: this is not guaranteed		
	}
	
	/**
	 * Test the rescaling of a is diagonal with some element < 1.e^16.
	 */
	public void testGetConditionNumberDiagonal() throws Exception {
		logger.debug("testGetConditionNumberDiagonal");
		DoubleFactory2D F2 = DoubleFactory2D.dense;
		DoubleFactory1D F1 = DoubleFactory1D.dense;
		Algebra ALG = Algebra.DEFAULT;
		
		double[] A = new double[] {1.E-17,168148.06378241107,5.333317404302006E-11,9.724301428859958E-11,4.343924031677448E-10,53042.618161481514,1.2550281021203455E-12,55714.086057404944,16564.267914476874,1.6265469281243343E-12,7.228925943265697E-11,19486.564364392565,315531.47099006834,236523.83171379057,202769.6735227342,2.4925304834427544E-13,2.7996276724404553E-13,2.069135405949759E-12,2530058.817281487,4.663208124742273E-15,2.5926311225234777E-12,2454865.060218241,7.564594931528804E-14,2.944935006524965E-13,7.938509176903875E-13,2546775.969599124,4.36659839706981E-15,3.772728220251383E-9,985020.987902404,971715.0611851265,1941150.6250316042,3.3787344131154E-10,2.8903135775881254E-11,1263.9864262585922,873899.9914494107,153097.08545910483,3.738245318154646E-11,1267390.1117847422,6.50494734416794E-10,3.588511203703992E-11,1231.6604599987518,3.772810869560189E-9,85338.92515278656,3.7382488244903144E-11,437165.36165859725,9.954549425029816E-11,1.8376434881340742E-9,86069.90894488744,1.2087907925307217E11,1.1990761432334067E11,1.163424797835085E11,1.1205515861349094E11,1.2004378300642543E11,8.219259112337953E8,1.1244633984805448E-11,1.1373907469271675E-12,1.9743774924311214E-12,6.301661187526759E-16,6.249382377266375E-16,8.298198098742164E-16,6.447686765999485E-16,1.742229837554675E-16,1.663041351618635E-16};
		DoubleMatrix1D b = ColtUtils.randomValuesVector(A.length, -1, 1, 12345L);
		
		DoubleMatrix2D AMatrix = F2.diagonal(F1.make(A));
		MatrixRescaler rescaler = new MatrixLogSumRescaler();
		DoubleMatrix1D U = rescaler.getMatrixScalingFactorsSymm(AMatrix);
		DoubleMatrix2D AScaled = ColtUtils.diagonalMatrixMult(U, AMatrix, U);
		//logger.debug("AScaled: " + ArrayUtils.toString(AScaled.toArray()));
		
		double cn_original = new SingularValueDecomposition(AMatrix).cond();
		double[] cn_2_original = ColtUtils.getConditionNumberRange(AMatrix, 2);
		double[] cn_00_original = ColtUtils.getConditionNumberRange(AMatrix, Integer.MAX_VALUE);
		double cn_scaled = new SingularValueDecomposition(AScaled).cond();
		double[] cn_2_scaled = ColtUtils.getConditionNumberRange(AScaled, Integer.MAX_VALUE);
		double[] cn_00_scaled = ColtUtils.getConditionNumberRange(AScaled, Integer.MAX_VALUE);
		logger.debug("cn_original   : " + ArrayUtils.toString(cn_original));
		logger.debug("cn_2_original : " + ArrayUtils.toString(cn_2_original));
		logger.debug("cn_00_original: " + ArrayUtils.toString(cn_00_original));
		logger.debug("cn_scaled     : " + ArrayUtils.toString(cn_scaled));
		logger.debug("cn_2_scaled   : " + ArrayUtils.toString(cn_2_scaled));
		logger.debug("cn_00_scaled  : " + ArrayUtils.toString(cn_00_scaled));
		
		assertTrue(rescaler.checkScaling(AMatrix, U, U));//NB: this MUST BE guaranteed by the scaling algorithm
		logger.debug("better matrix norm: " + (cn_scaled < cn_original));
		assertTrue(cn_scaled < cn_original);//NB: this IS NOT guaranteed by the scaling algorithm
	}
	
	/**
	 * Test the condition number before and after scaling.
	 * Note that scaling is not guaranteed to give a better condition number.
	 * The test shows some issue with condition number, in that this type of scaling
	 * in not effective in the condition number with this matrix.
	 */
	public void testGetConditionNumberFromFile7() throws Exception {
		logger.debug("testGetConditionNumberFromFile7");
		DoubleFactory2D F2 = DoubleFactory2D.dense;
		DoubleFactory1D F1 = DoubleFactory1D.dense;
		Algebra ALG = Algebra.DEFAULT;
		
		String matrixId = "7";
		double[][] A = super.loadDoubleMatrixFromFile("factorization" + File.separator + "matrix" + matrixId + ".csv", ",".charAt(0));
		DoubleMatrix2D AMatrix = F2.make(A);

		MatrixRescaler rescaler = new MatrixLogSumRescaler();
		DoubleMatrix1D Uv = rescaler.getMatrixScalingFactorsSymm(AMatrix);
		DoubleMatrix2D AScaled = ColtUtils.diagonalMatrixMult(Uv, AMatrix, Uv);
		//logger.debug("AScaled: " + ArrayUtils.toString(AScaled.toArray()));
		
		double cn_original = new SingularValueDecomposition(AMatrix).cond();
		double[] cn_2_original = ColtUtils.getConditionNumberRange(AMatrix, 2);
		double[] cn_00_original = ColtUtils.getConditionNumberRange(AMatrix, Integer.MAX_VALUE);
		double cn_scaled = new SingularValueDecomposition(AScaled).cond();
		double[] cn_2_scaled = ColtUtils.getConditionNumberRange(AScaled, Integer.MAX_VALUE);
		double[] cn_00_scaled = ColtUtils.getConditionNumberRange(AScaled, Integer.MAX_VALUE);
		logger.debug("cn_original   : " + ArrayUtils.toString(cn_original));
		logger.debug("cn_2_original : " + ArrayUtils.toString(cn_2_original));
		logger.debug("cn_00_original: " + ArrayUtils.toString(cn_00_original));
		logger.debug("cn_scaled     : " + ArrayUtils.toString(cn_scaled));
		logger.debug("cn_2_scaled   : " + ArrayUtils.toString(cn_2_scaled));
		logger.debug("cn_00_scaled  : " + ArrayUtils.toString(cn_00_scaled));

		assertTrue(rescaler.checkScaling(AMatrix, Uv, Uv));//NB: this MUST BE guaranteed by the scaling algorithm
		logger.debug("better matrix norm: " + (cn_scaled < cn_original));
	  //assertTrue(cn_scaled < cn_original);//NB: this IS NOT guaranteed by the scaling algorithm
	}
	
	/**
	 * Test the condition number before and after scaling.
	 * Note that scaling is not guaranteed to give a better condition number.
	 * The test shows some issue with condition number, in that this type of scaling
	 * in not effective in the condition number with this matrix.
	 */
	public void testGetConditionNumberFromFile13() throws Exception {
		logger.debug("testGetConditionNumberFromFile13");
		DoubleFactory2D F2 = DoubleFactory2D.dense;
		DoubleFactory1D F1 = DoubleFactory1D.dense;
		Algebra ALG = Algebra.DEFAULT;
		
		String matrixId = "13";
		double[][] A = super.loadDoubleMatrixFromFile("factorization" + File.separator + "matrix" + matrixId + ".csv");
		DoubleMatrix2D AMatrix = F2.make(A);

		MatrixRescaler rescaler = new MatrixLogSumRescaler();
		DoubleMatrix1D Uv = rescaler.getMatrixScalingFactorsSymm(AMatrix);
		DoubleMatrix2D AScaled = ColtUtils.diagonalMatrixMult(Uv, AMatrix, Uv);		
		//logger.debug("AScaled: " + ArrayUtils.toString(AScaled.toArray()));
		
		double cn_original = new SingularValueDecomposition(AMatrix).cond();
		double[] cn_2_original = ColtUtils.getConditionNumberRange(AMatrix, 2);
		double[] cn_00_original = ColtUtils.getConditionNumberRange(AMatrix, Integer.MAX_VALUE);
		double cn_scaled = new SingularValueDecomposition(AScaled).cond();
		double[] cn_2_scaled = ColtUtils.getConditionNumberRange(AScaled, Integer.MAX_VALUE);
		double[] cn_00_scaled = ColtUtils.getConditionNumberRange(AScaled, Integer.MAX_VALUE);
		logger.debug("cn_original   : " + ArrayUtils.toString(cn_original));
		logger.debug("cn_2_original : " + ArrayUtils.toString(cn_2_original));
		logger.debug("cn_00_original: " + ArrayUtils.toString(cn_00_original));
		logger.debug("cn_scaled     : " + ArrayUtils.toString(cn_scaled));
		logger.debug("cn_2_scaled   : " + ArrayUtils.toString(cn_2_scaled));
		logger.debug("cn_00_scaled  : " + ArrayUtils.toString(cn_00_scaled));
		
		assertTrue(rescaler.checkScaling(AMatrix, Uv, Uv));//NB: this MUST BE guaranteed by the scaling algorithm
		logger.debug("better matrix norm: " + (cn_scaled < cn_original));
		//assertTrue(cn_scaled < cn_original);//NB: this IS NOT guaranteed by the scaling algorithm
	}
}
