/*
 * Copyright 2025-present Solver4J
 *
 * This work is licensed under the Creative Commons Attribution-NoDerivatives 4.0 
 * International License. To view a copy of this license, visit 
 *
 *        http://creativecommons.org/licenses/by-nd/4.0/ 
 *
 * or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
 */
package com.solver4j.linear.kkt;

import org.apache.commons.lang3.ArrayUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.solver4j.util.Solver4JBaseTest;

import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.colt.matrix.linalg.Property;
import cern.jet.math.Functions;

/**
 * @author <a href="mailto:orion.waverly@gmail.com">Orion Waverly</a>
 */
public class BasicKKTSolverTest extends Solver4JBaseTest {

	/** Absolute tolerance applied to every component of both KKT residuals. */
	private static final double TOLERANCE = 1.E-14;

	private static final Algebra ALG = Algebra.DEFAULT;
	private static final DoubleFactory1D F1 = DoubleFactory1D.dense;
	private static final DoubleFactory2D F2 = DoubleFactory2D.dense;
	private final Logger logger = LoggerFactory.getLogger(this.getClass().getName());

	/**
	 * One-dimensional system: H = [3], A = [2], g = [-3], h = [0].
	 */
	public void testSolveSimple() throws Exception {
		logger.debug("testSolveSimple");
		double[][] HMatrix = new double[][] { { 3 } };
		double[][] AMatrix = new double[][] { { 2 } };
		DoubleMatrix1D g = F1.make(1, -3);
		DoubleMatrix1D h = F1.make(1, 0);

		solveAndAssertKKT(F2.make(HMatrix), F2.make(AMatrix), g, h);
	}

	/**
	 * Three variables with a single equality row A = [1 2 3].
	 */
	public void testSolve2() throws Exception {
		logger.debug("testSolve2");
		double[][] HMatrix = new double[][] { 
				{ 1.68, 0.34, 0.38 },
				{ 0.34, 3.09, -1.59 }, 
				{ 0.38, -1.59, 1.54 } };
		double[][] AMatrix = new double[][] { { 1, 2, 3 } };
		DoubleMatrix1D g = F1.make(new double[] { 2, 5, 1 });
		DoubleMatrix1D h = F1.make(new double[] { 1 });

		solveAndAssertKKT(F2.make(HMatrix), F2.make(AMatrix), g, h);
	}

	/**
	 * Solves the system defined by {@code H}, {@code A}, {@code g}, {@code h} with a fresh
	 * {@link BasicKKTSolver} and asserts that the returned pair {@code (v, w)} makes the residuals
	 * <pre>
	 *   a = H.v + A^T.w + g
	 *   b = A.v + h
	 * </pre>
	 * zero in every component, within {@link #TOLERANCE}.
	 *
	 * @param H the H matrix handed to the solver
	 * @param A the A matrix handed to the solver
	 * @param g the g vector handed to the solver
	 * @param h the h vector handed to the solver
	 * @throws Exception whatever {@code solver.solve()} throws
	 */
	private void solveAndAssertKKT(DoubleMatrix2D H, DoubleMatrix2D A, DoubleMatrix1D g, DoubleMatrix1D h)
			throws Exception {
		// Snapshot A^T before solving, so the check is independent of anything solve() does to A.
		DoubleMatrix2D AT = ALG.transpose(A.copy());

		KKTSolver solver = new BasicKKTSolver();
		solver.setHMatrix(H);
		solver.setAMatrix(A);
		solver.setGVector(g);
		solver.setHVector(h);
		DoubleMatrix1D[] sol = solver.solve();
		DoubleMatrix1D v = sol[0];
		DoubleMatrix1D w = sol[1];
		logger.debug("v: " + ArrayUtils.toString(v.toArray()));
		logger.debug("w: " + ArrayUtils.toString(w.toArray()));

		// ALG.mult returns a fresh vector, so the in-place assign calls do not touch H, A, g or h.
		DoubleMatrix1D a = ALG.mult(H, v).assign(ALG.mult(AT, w), Functions.plus).assign(g, Functions.plus);
		DoubleMatrix1D b = ALG.mult(A, v).assign(h, Functions.plus);
		logger.debug("a: " + ArrayUtils.toString(a.toArray()));
		logger.debug("b: " + ArrayUtils.toString(b.toArray()));

		assertAllZero(a);
		assertAllZero(b);
	}

	/**
	 * Asserts that every component of {@code residual} equals zero within {@link #TOLERANCE}.
	 */
	private void assertAllZero(DoubleMatrix1D residual) {
		for (int i = 0; i < residual.size(); i++) {
			assertEquals(0, residual.get(i), TOLERANCE);
		}
	}
}
