/*
 * Copyright 2025-present Solver4J
 *
 * This work is licensed under the Creative Commons Attribution-NoDerivatives 4.0 
 * International License. To view a copy of this license, visit 
 *
 *        http://creativecommons.org/licenses/by-nd/4.0/ 
 *
 * or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
 */
package com.solver4j.linear.factorization;

import java.io.File;

import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.solver4j.util.ColtUtils;
import com.solver4j.util.Solver4JBaseTest;
import com.solver4j.util.Utils;

import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.impl.SparseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;

/**
 * Tests for {@link LDLTFactorization}: reconstruction of the input matrix from
 * the computed factors, linear-system solving, rejection of a singular matrix
 * and factorization of a matrix loaded from a test resource file.
 *
 * @author <a href="mailto:orion.waverly@gmail.com">Orion Waverly</a>
 */
public class LDLTFactorizationTest extends Solver4JBaseTest {
	private static final Logger logger = LoggerFactory.getLogger(LDLTFactorizationTest.class);

	/**
	 * Factorizes a symmetric 5x5 matrix and checks that L.D.LT reproduces it.
	 */
	public void testFactorize1() throws Exception {
		logger.debug("testFactorize1");
		double[][] QData = new double[][] {
				{ 1, .12, .13, .14, .15 },
				{ .12, 2, .23, .24, .25 },
				{ .13, .23, 3, 0, 0 },
				{ .14, .24, 0, 4, 0 },
				{ .15, .25, 0, 0, 5 } };
		assertFactorsReconstruct(QData);
	}

	/**
	 * Factorizes a 2x2 matrix with large entries, using a
	 * {@link Matrix1NormRescaler}, then solves a linear system with it and
	 * checks the scaled residual of the solution.
	 */
	public void testFactorization2() throws Exception {
		logger.debug("testFactorization2");
		DoubleMatrix2D P1 = DoubleFactory2D.dense.make(new double[][] {
				{ 8.185301256666552E9, 1.5977225251367908E9 },
				{ 1.5977225251367908E9, 3.118660129093004E8 } });
		// NOTE(review): the meaning of the two boolean flags is defined by
		// LDLTFactorization, which is not visible from this file.
		LDLTFactorization cFact1 = new LDLTFactorization(P1, new Matrix1NormRescaler(), true);
		cFact1.factorize(true);

		// solve P1.x = b for a reproducible (fixed seed) right-hand side
		DoubleMatrix1D b = ColtUtils.randomValuesVector(P1.rows(), -1, 1, 12345L);
		DoubleMatrix1D x = cFact1.solve(b);
		double scaledResidualx_1 = ColtUtils.calculateScaledResidual(P1, x, b);
		logger.debug("scaledResidualx_1: {}", scaledResidualx_1);
		assertTrue(scaledResidualx_1 < Utils.getDoubleMachineEpsilon());
	}

	/**
	 * Factorizes the indefinite matrix diag(1, -1) and checks that L.D.LT
	 * reproduces it: a non positive definite input must still be handled.
	 */
	public void testFactorizeNotPositive() throws Exception {
		logger.debug("testFactorizeNotPositive");
		double[][] QData = new double[][] {
				{ 1, 0 },
				{ 0, -1 } };
		assertFactorsReconstruct(QData);
	}

	/**
	 * Checks that factorizing a singular matrix (first and third rows are
	 * equal) is rejected with an exception.
	 */
	public void testFactorizeSingular() throws Exception {
		logger.debug("testFactorizeSingular");
		double[][] QData = new double[][] {
				{ 1, 0, 1 },
				{ 0, -1, 0 },
				{ 1, 0, 1 } };

		try {
			LDLTFactorization myc = new LDLTFactorization(new DenseDoubleMatrix2D(QData));
			myc.factorize();

			fail();// the factorization must detect the singularity
		} catch (Exception expected) {
			// OK: the singularity has been detected.
			// NOTE(review): this relies on fail() signalling with an Error rather
			// than an Exception, so that it is not swallowed here - confirm
			// against the test base class.
		}
	}

	/**
	 * Factorizes the matrix stored in the "matrix6.csv" test resource, falling
	 * back to a rescaled factorization when the plain one throws.
	 * The matrix has a regular Cholesky factorization (as given by Mathematica)
	 * so Solver4J must be able to factorize it.
	 * <p>
	 * NOTE(review): the original comment referred to "matrix7" while the code
	 * loads matrix 6 - confirm which one was meant.
	 * <p>
	 * NOTE(review): all the assertions live in the fallback branch, so nothing
	 * is verified when the plain factorization succeeds.
	 */
	public void testScale6() throws Exception {
		logger.debug("testScale6");
		DoubleFactory2D F2 = DoubleFactory2D.sparse;
		Algebra ALG = Algebra.DEFAULT;

		String matrixId = "6";
		double[][] A = super.loadDoubleMatrixFromFile("factorization" + File.separator + "matrix" + matrixId + ".csv", ',');
		SparseDoubleMatrix2D AMatrix = (SparseDoubleMatrix2D) F2.make(A);
		int dim = AMatrix.rows();

		LDLTFactorization myc;
		try {
			myc = new LDLTFactorization(new DenseDoubleMatrix2D(A));
			myc.factorize();
		} catch (Exception e) {
			logger.debug("numeric problem, try to rescale the matrix");
			MatrixRescaler rescaler = new Matrix1NormRescaler();
			DoubleMatrix1D Uv = rescaler.getMatrixScalingFactorsSymm(AMatrix);
			DoubleMatrix2D U = F2.diagonal(Uv);

			assertTrue(rescaler.checkScaling(ColtUtils.fillSubdiagonalSymmetricMatrix(AMatrix), Uv, Uv));

			DoubleMatrix2D AScaled = ColtUtils.diagonalMatrixMult(Uv, AMatrix, Uv);
			myc = new LDLTFactorization(AScaled);
			myc.factorize();

			// NOTE: with scaling, we must solve U.A.U.z = U.b, after that we have x = U.z

			// solve A.x = b through the scaled system; the variable x below holds z
			DoubleMatrix1D b = ColtUtils.randomValuesVector(dim, -1, 1, 12345L);
			DoubleMatrix1D x = myc.solve(ALG.mult(U, b));
			double scaledResidualx = ColtUtils.calculateScaledResidual(AMatrix, ALG.mult(U, x), b);
			logger.debug("scaledResidualx: {}", scaledResidualx);
			assertTrue(scaledResidualx < 1.e-15);
		}
	}

	/**
	 * Factorizes the given matrix, logs the factors and asserts that the norm
	 * of L.D.LT - Q is below the double machine epsilon.
	 *
	 * @param QData the entries of the matrix Q to factorize
	 * @throws Exception if the factorization fails
	 */
	private void assertFactorsReconstruct(double[][] QData) throws Exception {
		RealMatrix Q = MatrixUtils.createRealMatrix(QData);

		LDLTFactorization myc = new LDLTFactorization(new DenseDoubleMatrix2D(QData));
		myc.factorize();
		RealMatrix L = new Array2DRowRealMatrix(myc.getL().toArray());
		RealMatrix D = new Array2DRowRealMatrix(myc.getD().toArray());
		RealMatrix LT = new Array2DRowRealMatrix(myc.getLT().toArray());
		logger.debug("L: {}", L);
		logger.debug("D: {}", D);
		logger.debug("LT: {}", LT);
		logger.debug("L.D.LT: {}", L.multiply(D.multiply(LT)));

		// check Q = L.D.LT
		double norm = L.multiply(D).multiply(LT).subtract(Q).getNorm();
		logger.debug("norm: {}", norm);
		assertTrue(norm < Utils.getDoubleMachineEpsilon());
	}
}
