using System;
using System.Diagnostics;
using Passer.LinearAlgebra;
#if UNITY_5_3_OR_NEWER
using Vector3Float = UnityEngine.Vector3;
using Vector2Float = UnityEngine.Vector2;
#endif
using Quaternion = UnityEngine.Quaternion;

/// <summary>
/// An immutable index range described by a <see cref="start"/> and a <see cref="stop"/> index.
/// </summary>
/// <remarks>
/// The slicing methods of <c>Matrix1</c> and <c>Matrix2</c> in this file iterate
/// <c>start &lt;= ix &lt; stop</c>, i.e. they treat the range as half-open.
/// </remarks>
public readonly struct Slice {
    /// <summary>First index of the range.</summary>
    public uint start { get; }

    /// <summary>End index of the range.</summary>
    public uint stop { get; }

    /// <summary>Creates a range from <paramref name="start"/> to <paramref name="stop"/>.</summary>
    /// <param name="start">First index of the range.</param>
    /// <param name="stop">End index of the range.</param>
    public Slice(uint start, uint stop) {
        this.stop = stop;
        this.start = start;
    }
}

/// <summary>
/// A dense two-dimensional matrix of <see cref="float"/> values, stored in a
/// <c>float[,]</c> array indexed as <c>data[row, column]</c>.
/// </summary>
public class Matrix2 {
    /// <summary>The backing array, indexed as <c>data[row, column]</c>.</summary>
    public float[,] data { get; }

    /// <summary>Number of rows (first dimension of <see cref="data"/>).</summary>
    public uint nRows => (uint)data.GetLength(0);
    /// <summary>Number of columns (second dimension of <see cref="data"/>).</summary>
    public uint nCols => (uint)data.GetLength(1);

    /// <summary>Creates a zero-filled matrix with the given dimensions.</summary>
    public Matrix2(uint nRows, uint nCols) {
        this.data = new float[nRows, nCols];
    }
    /// <summary>Wraps <paramref name="data"/> directly; the array is not copied.</summary>
    public Matrix2(float[,] data) {
        this.data = data;
    }

    /// <summary>Returns a copy of this matrix that does not share storage with it.</summary>
    public Matrix2 Clone() {
        // Array.Clone is shallow, which is a complete copy for an array of value types.
        return new Matrix2((float[,])this.data.Clone());
    }

    /// <summary>Returns a new zero-filled matrix with the given dimensions.</summary>
    public static Matrix2 Zero(uint nRows, uint nCols) {
        return new Matrix2(nRows, nCols);
    }

    /// <summary>Returns <paramref name="v"/> as a 3x1 column matrix (x, y, z).</summary>
    public static Matrix2 FromVector3(Vector3Float v) {
        float[,] result = new float[3, 1];
        result[0, 0] = v.x;
        result[1, 0] = v.y;
        result[2, 0] = v.z;
        return new Matrix2(result);
    }

    /// <summary>Returns the <paramref name="size"/> x <paramref name="size"/> identity matrix.</summary>
    public static Matrix2 Identity(uint size) {
        return Diagonal(1, size);
    }
    /// <summary>
    /// Returns an <paramref name="nRows"/> x <paramref name="nCols"/> matrix with ones on the
    /// main diagonal and zeros elsewhere; non-square shapes are allowed.
    /// </summary>
    public static Matrix2 Identity(uint nRows, uint nCols) {
        Matrix2 m = Zero(nRows, nCols);
        m.FillDiagonal(1);
        return m;
    }

    /// <summary>Returns a square matrix with the elements of <paramref name="v"/> on its diagonal.</summary>
    public static Matrix2 Diagonal(Matrix1 v) {
        float[,] resultData = new float[v.size, v.size];
        for (int ix = 0; ix < v.size; ix++)
            resultData[ix, ix] = v.data[ix];
        return new Matrix2(resultData);
    }
    /// <summary>Returns a <paramref name="size"/> x <paramref name="size"/> matrix with <paramref name="f"/> on the diagonal.</summary>
    public static Matrix2 Diagonal(float f, uint size) {
        float[,] resultData = new float[size, size];
        for (int ix = 0; ix < size; ix++)
            resultData[ix, ix] = f;
        return new Matrix2(resultData);
    }
    /// <summary>
    /// Overwrites the main diagonal with the elements of <paramref name="v"/>. Only the first
    /// min(nRows, nCols, v.size) diagonal elements are written; all other elements are unchanged.
    /// </summary>
    public void FillDiagonal(Matrix1 v) {
        uint n = Math.Min(Math.Min(this.nRows, this.nCols), v.size);
        for (int ix = 0; ix < n; ix++)
            this.data[ix, ix] = v.data[ix];
    }
    /// <summary>Sets every main-diagonal element to <paramref name="f"/>; off-diagonal elements are unchanged.</summary>
    public void FillDiagonal(float f) {
        uint n = Math.Min(this.nRows, this.nCols);
        for (int ix = 0; ix < n; ix++)
            this.data[ix, ix] = f;
    }

    /// <summary>
    /// Returns the 3x3 skew-symmetric matrix of <paramref name="v"/>. Multiplying it by a
    /// column vector w yields the cross product v x w.
    /// </summary>
    public static Matrix2 SkewMatrix(Vector3Float v) {
        float[,] result = new float[3, 3] {
            {0, -v.z, v.y},
            {v.z, 0, -v.x},
            {-v.y, v.x, 0}
        };
        return new Matrix2(result);
    }

    /// <summary>Returns the first three elements of row <paramref name="rowIx"/> as a vector.</summary>
    public Vector3Float GetRow3(int rowIx) {
        return new Vector3Float(
            this.data[rowIx, 0],
            this.data[rowIx, 1],
            this.data[rowIx, 2]
        );
    }
    /// <summary>Writes the elements of <paramref name="v"/> into row <paramref name="rowIx"/>, starting at column 0.</summary>
    public void SetRow(int rowIx, Matrix1 v) {
        for (uint ix = 0; ix < v.size; ix++)
            this.data[rowIx, ix] = v.data[ix];
    }
    /// <summary>Writes (x, y, z) of <paramref name="v"/> into columns 0..2 of row <paramref name="rowIx"/>.</summary>
    public void SetRow3(int rowIx, Vector3Float v) {
        this.data[rowIx, 0] = v.x;
        this.data[rowIx, 1] = v.y;
        this.data[rowIx, 2] = v.z;
    }

    /// <summary>Returns a copy of column <paramref name="colIx"/>.</summary>
    public Matrix1 GetColumn(int colIx) {
        float[] column = new float[this.nRows];
        for (int i = 0; i < this.nRows; i++)
            column[i] = this.data[i, colIx];
        return new Matrix1(column);
    }

    /// <summary>
    /// Returns true when <paramref name="A"/> and <paramref name="B"/> have the same shape and
    /// every pair of corresponding elements differs by at most <paramref name="atol"/>.
    /// </summary>
    /// <remarks>
    /// Matrices of different shapes are never close. A NaN element makes the result false.
    /// </remarks>
    public static bool AllClose(Matrix2 A, Matrix2 B, float atol = 1e-08f) {
        if (A.nRows != B.nRows || A.nCols != B.nCols)
            return false;

        for (int i = 0; i < A.nRows; i++) {
            for (int j = 0; j < A.nCols; j++) {
                float d = MathF.Abs(A.data[i, j] - B.data[i, j]);
                // Written as !(d <= atol) so that a NaN difference counts as "not close".
                if (!(d <= atol))
                    return false;
            }
        }
        return true;
    }

    /// <summary>Returns a new nCols x nRows matrix that is the transpose of this one.</summary>
    public Matrix2 Transpose() {
        // double checked code
        float[,] resultData = new float[this.nCols, this.nRows];
        for (uint rowIx = 0; rowIx < this.nRows; rowIx++) {
            for (uint colIx = 0; colIx < this.nCols; colIx++)
                resultData[colIx, rowIx] = this.data[rowIx, colIx];
        }
        return new Matrix2(resultData);
    }
    /// <summary>The transpose of this matrix (a new instance on every access).</summary>
    public Matrix2 transposed {
        get => Transpose();
    }

    /// <summary>Element-wise negation.</summary>
    public static Matrix2 operator -(Matrix2 m) {
        float[,] result = new float[m.nRows, m.nCols];
        for (int i = 0; i < m.nRows; i++) {
            for (int j = 0; j < m.nCols; j++)
                result[i, j] = -m.data[i, j];
        }
        return new Matrix2(result);
    }

    /// <summary>Element-wise difference.</summary>
    /// <exception cref="ArgumentException">The shapes of A and B differ.</exception>
    public static Matrix2 operator -(Matrix2 A, Matrix2 B) {
        if (A.nRows != B.nRows || A.nCols != B.nCols)
            throw new System.ArgumentException("Size of A must match size of B.");

        float[,] result = new float[A.nRows, A.nCols];
        for (int i = 0; i < A.nRows; i++) {
            for (int j = 0; j < A.nCols; j++)
                result[i, j] = A.data[i, j] - B.data[i, j];
        }
        return new Matrix2(result);
    }

    /// <summary>Element-wise sum.</summary>
    /// <exception cref="ArgumentException">The shapes of A and B differ.</exception>
    public static Matrix2 operator +(Matrix2 A, Matrix2 B) {
        if (A.nRows != B.nRows || A.nCols != B.nCols)
            throw new System.ArgumentException("Size of A must match size of B.");

        float[,] result = new float[A.nRows, A.nCols];
        for (int i = 0; i < A.nRows; i++) {
            for (int j = 0; j < A.nCols; j++)
                result[i, j] = A.data[i, j] + B.data[i, j];
        }
        return new Matrix2(result);
    }

    /// <summary>Matrix product A * B.</summary>
    /// <exception cref="ArgumentException">A.nCols differs from B.nRows.</exception>
    public static Matrix2 operator *(Matrix2 A, Matrix2 B) {
        // double checked code
        if (A.nCols != B.nRows)
            throw new System.ArgumentException("Number of columns in A must match number of rows in B.");

        float[,] result = new float[A.nRows, B.nCols];
        for (int i = 0; i < A.nRows; i++) {
            for (int j = 0; j < B.nCols; j++) {
                float sum = 0.0f;
                for (int k = 0; k < A.nCols; k++)
                    sum += A.data[i, k] * B.data[k, j];
                result[i, j] = sum;
            }
        }
        return new Matrix2(result);
    }

    /// <summary>Matrix-vector product A * v, where v is treated as a column vector.</summary>
    /// <exception cref="ArgumentException">A.nCols differs from v.size.</exception>
    public static Matrix1 operator *(Matrix2 A, Matrix1 v) {
        // Without this check a longer v would be silently truncated and a shorter one would
        // fail with an unhelpful IndexOutOfRangeException.
        if (A.nCols != v.size)
            throw new System.ArgumentException("Number of columns in A must match size of v.");

        float[] result = new float[A.nRows];
        for (int i = 0; i < A.nRows; i++) {
            float sum = 0.0f;
            for (int j = 0; j < A.nCols; j++)
                sum += A.data[i, j] * v.data[j];
            result[i] = sum;
        }
        return new Matrix1(result);
    }

    /// <summary>
    /// Multiplies the upper-left 3x3 block of <paramref name="A"/> by <paramref name="v"/>.
    /// Any further rows or columns of A are ignored.
    /// </summary>
    public static Vector3Float operator *(Matrix2 A, Vector3Float v) {
        return new Vector3Float(
            A.data[0, 0] * v.x + A.data[0, 1] * v.y + A.data[0, 2] * v.z,
            A.data[1, 0] * v.x + A.data[1, 1] * v.y + A.data[1, 2] * v.z,
            A.data[2, 0] * v.x + A.data[2, 1] * v.y + A.data[2, 2] * v.z
        );
    }

    /// <summary>Multiplies every element by the scalar <paramref name="s"/>.</summary>
    public static Matrix2 operator *(Matrix2 A, float s) {
        float[,] result = new float[A.nRows, A.nCols];
        for (int i = 0; i < A.nRows; i++) {
            for (int j = 0; j < A.nCols; j++)
                result[i, j] = A.data[i, j] * s;
        }
        return new Matrix2(result);
    }
    /// <summary>Multiplies every element by the scalar <paramref name="s"/>.</summary>
    public static Matrix2 operator *(float s, Matrix2 A) {
        return A * s;
    }

    /// <summary>Divides every element by the scalar <paramref name="s"/>.</summary>
    public static Matrix2 operator /(Matrix2 A, float s) {
        float[,] result = new float[A.nRows, A.nCols];
        for (int i = 0; i < A.nRows; i++) {
            for (int j = 0; j < A.nCols; j++)
                result[i, j] = A.data[i, j] / s;
        }
        return new Matrix2(result);
    }
    /// <summary>
    /// Element-wise <c>s / A[i, j]</c>. This is not a matrix inverse; see <see cref="Inverse"/>.
    /// </summary>
    public static Matrix2 operator /(float s, Matrix2 A) {
        float[,] result = new float[A.nRows, A.nCols];
        for (int i = 0; i < A.nRows; i++) {
            for (int j = 0; j < A.nCols; j++)
                result[i, j] = s / A.data[i, j];
        }
        return new Matrix2(result);
    }

    /// <summary>Returns rows [slice.start, slice.stop) with all columns.</summary>
    public Matrix2 Slice(Slice slice) {
        return Slice(slice.start, slice.stop);
    }
    /// <summary>Returns rows [<paramref name="from"/>, <paramref name="to"/>) with all columns.</summary>
    /// <exception cref="ArgumentException">from &gt; to, or to &gt; nRows.</exception>
    public Matrix2 Slice(uint from, uint to) {
        // 'to' is exclusive, so to == nRows is valid and selects through the last row.
        if (from > to || to > this.nRows)
            throw new System.ArgumentException("Slice index out of range.");

        return Slice((from, to), (0u, this.nCols));
    }
    /// <summary>Returns the sub-matrix covered by the half-open row and column ranges.</summary>
    public Matrix2 Slice(Slice rowRange, Slice colRange) {
        return Slice((rowRange.start, rowRange.stop), (colRange.start, colRange.stop));
    }

    /// <summary>
    /// Returns a copy of rows [rowRange.start, rowRange.stop) and columns
    /// [colRange.start, colRange.stop).
    /// </summary>
    /// <exception cref="ArgumentException">A range is reversed or extends past the matrix.</exception>
    public Matrix2 Slice((uint start, uint stop) rowRange, (uint start, uint stop) colRange) {
        if (rowRange.start > rowRange.stop || rowRange.stop > this.nRows ||
            colRange.start > colRange.stop || colRange.stop > this.nCols)
            throw new System.ArgumentException("Slice index out of range.");

        uint rowCount = rowRange.stop - rowRange.start;
        uint colCount = colRange.stop - colRange.start;
        float[,] result = new float[rowCount, colCount];
        for (uint i = 0; i < rowCount; i++) {
            for (uint j = 0; j < colCount; j++)
                result[i, j] = this.data[rowRange.start + i, colRange.start + j];
        }
        return new Matrix2(result);
    }

    /// <summary>
    /// Overwrites rows [slice.start, slice.stop) of this matrix with the rows of
    /// <paramref name="m"/>, starting at m's row 0 and covering all of this matrix's columns.
    /// </summary>
    public void UpdateSlice(Slice slice, Matrix2 m) {
        int mRowIx = 0;
        for (uint rowIx = slice.start; rowIx < slice.stop; rowIx++, mRowIx++) {
            for (int colIx = 0; colIx < this.nCols; colIx++)
                this.data[rowIx, colIx] = m.data[mRowIx, colIx];
        }
    }
    /// <summary>Overwrites the given row/column block of this matrix with <paramref name="m"/>.</summary>
    public void UpdateSlice(Slice rowRange, Slice colRange, Matrix2 m) {
        UpdateSlice((rowRange.start, rowRange.stop), (colRange.start, colRange.stop), m);
    }
    /// <summary>
    /// Overwrites the block of rows [rowRange.start, rowRange.stop) and columns
    /// [colRange.start, colRange.stop) with <paramref name="m"/>, reading m from index (0, 0).
    /// </summary>
    public void UpdateSlice((uint start, uint stop) rowRange, (uint start, uint stop) colRange, Matrix2 m) {
        for (uint i = rowRange.start; i < rowRange.stop; i++) {
            for (uint j = colRange.start; j < colRange.stop; j++)
                this.data[i, j] = m.data[i - rowRange.start, j - colRange.start];
        }
    }

    /// <summary>
    /// Computes the inverse of this square matrix by Gauss-Jordan elimination with partial
    /// pivoting.
    /// </summary>
    /// <returns>A new matrix B such that this * B is (up to round-off) the identity.</returns>
    /// <exception cref="ArgumentException">The matrix is not square.</exception>
    /// <exception cref="InvalidOperationException">The matrix is singular (no usable pivot).</exception>
    public Matrix2 Inverse() {
        if (this.nRows != this.nCols)
            throw new System.ArgumentException("Matrix must be square.");

        int n = (int)this.nRows;
        int width = 2 * n;

        // Build the augmented matrix [A | I].
        float[,] augmented = new float[n, width];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++)
                augmented[i, j] = this.data[i, j];
            augmented[i, n + i] = 1;
        }

        for (int col = 0; col < n; col++) {
            // Partial pivoting picks the row with the largest magnitude in this column. Matrices
            // with a zero on the diagonal (e.g. permutations) can then still be inverted, and
            // round-off error is reduced.
            int pivotRow = col;
            float pivotAbs = MathF.Abs(augmented[col, col]);
            for (int row = col + 1; row < n; row++) {
                float candidate = MathF.Abs(augmented[row, col]);
                if (candidate > pivotAbs) {
                    pivotAbs = candidate;
                    pivotRow = row;
                }
            }
            if (pivotAbs < 1e-10f)
                throw new InvalidOperationException("Matrix is singular and cannot be inverted.");

            if (pivotRow != col) {
                for (int k = 0; k < width; k++) {
                    float tmp = augmented[col, k];
                    augmented[col, k] = augmented[pivotRow, k];
                    augmented[pivotRow, k] = tmp;
                }
            }

            // Normalize the pivot row so that the pivot becomes 1.
            float pivot = augmented[col, col];
            for (int k = 0; k < width; k++)
                augmented[col, k] /= pivot;

            // Eliminate this column from every other row, both above and below the pivot.
            // This makes a separate back-substitution pass unnecessary.
            for (int row = 0; row < n; row++) {
                if (row == col)
                    continue;
                float factor = augmented[row, col];
                for (int k = 0; k < width; k++)
                    augmented[row, k] -= factor * augmented[col, k];
            }
        }

        // The right half of the augmented matrix now holds the inverse.
        float[,] inverse = new float[n, n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++)
                inverse[i, j] = augmented[i, n + j];
        }
        return new Matrix2(inverse);
    }

    /// <summary>
    /// Computes the determinant by cofactor (Laplace) expansion along the first row.
    /// </summary>
    /// <remarks>
    /// Each level recurses into n minors, so the cost grows factorially with n. This is only
    /// suitable for small matrices.
    /// </remarks>
    /// <exception cref="ArgumentException">The matrix is not square.</exception>
    public float Determinant() {
        uint n = this.nRows;
        if (n != this.nCols)
            throw new System.ArgumentException("Matrix must be square.");

        if (n == 1)
            return this.data[0, 0]; // Base case for 1x1 matrix

        if (n == 2) // Base case for 2x2 matrix
            return this.data[0, 0] * this.data[1, 1] - this.data[0, 1] * this.data[1, 0];

        float det = 0;
        for (int col = 0; col < n; col++)
            det += (col % 2 == 0 ? 1 : -1) * this.data[0, col] * this.Minor(0, col).Determinant();

        return det;
    }

    /// <summary>
    /// Returns the (n-1)x(n-1) matrix obtained by deleting <paramref name="rowToRemove"/> and
    /// <paramref name="colToRemove"/>. Assumes the matrix is square (n = nRows).
    /// </summary>
    private Matrix2 Minor(int rowToRemove, int colToRemove) {
        uint n = this.nRows;
        float[,] minor = new float[n - 1, n - 1];

        int r = 0;
        for (int i = 0; i < n; i++) {
            if (i == rowToRemove)
                continue;

            int c = 0;
            for (int j = 0; j < n; j++) {
                if (j == colToRemove)
                    continue;

                minor[r, c] = this.data[i, j];
                c++;
            }
            r++;
        }

        return new Matrix2(minor);
    }
}

/// <summary>
/// A dense one-dimensional vector of <see cref="float"/> values.
/// </summary>
public class Matrix1 {
    /// <summary>The backing array.</summary>
    public float[] data { get; }

    /// <summary>Number of elements.</summary>
    public uint size => (uint)data.GetLength(0);

    /// <summary>Creates a zero-filled vector with <paramref name="size"/> elements.</summary>
    public Matrix1(uint size) {
        this.data = new float[size];
    }

    /// <summary>Wraps <paramref name="data"/> directly; the array is not copied.</summary>
    public Matrix1(float[] data) {
        this.data = data;
    }

    /// <summary>Returns a new zero-filled vector with <paramref name="size"/> elements.</summary>
    public static Matrix1 Zero(uint size) {
        return new Matrix1(size);
    }

    /// <summary>Returns <paramref name="v"/> as a 2-element vector (x, y).</summary>
    public static Matrix1 FromVector2(Vector2Float v) {
        float[] result = new float[2];
        result[0] = v.x;
        result[1] = v.y;
        return new Matrix1(result);
    }

    /// <summary>Returns <paramref name="v"/> as a 3-element vector (x, y, z).</summary>
    public static Matrix1 FromVector3(Vector3Float v) {
        float[] result = new float[3];
        result[0] = v.x;
        result[1] = v.y;
        result[2] = v.z;
        return new Matrix1(result);
    }

    /// <summary>Returns <paramref name="q"/> as a 4-element vector in (x, y, z, w) order.</summary>
    public static Matrix1 FromQuaternion(Quaternion q) {
        float[] result = new float[4];
        result[0] = q.x;
        result[1] = q.y;
        result[2] = q.z;
        result[3] = q.w;
        return new Matrix1(result);
    }

    /// <summary>This vector as a Vector2Float.</summary>
    /// <exception cref="ArgumentException">size is not 2.</exception>
    public Vector2Float vector2 {
        get {
            if (this.size != 2)
                throw new System.ArgumentException("Matrix1 must be of size 2");
            return new Vector2Float(this.data[0], this.data[1]);
        }
    }
    /// <summary>This vector as a Vector3Float.</summary>
    /// <exception cref="ArgumentException">size is not 3.</exception>
    public Vector3Float vector3 {
        get {
            if (this.size != 3)
                throw new System.ArgumentException("Matrix1 must be of size 3");
            return new Vector3Float(this.data[0], this.data[1], this.data[2]);
        }
    }
    /// <summary>This vector as a Quaternion, reading elements in (x, y, z, w) order.</summary>
    /// <exception cref="ArgumentException">size is not 4.</exception>
    public Quaternion quaternion {
        get {
            if (this.size != 4)
                throw new System.ArgumentException("Matrix1 must be of size 4");
            return new Quaternion(this.data[0], this.data[1], this.data[2], this.data[3]);
        }
    }

    /// <summary>Returns a copy of this vector that does not share storage with it.</summary>
    public Matrix1 Clone() {
        // Array.Clone is shallow, which is a complete copy for an array of value types.
        return new Matrix1((float[])this.data.Clone());
    }

    /// <summary>
    /// The Euclidean length (L2 norm) of the vector: sqrt of the sum of squared elements.
    /// An empty vector has magnitude 0.
    /// </summary>
    public float magnitude {
        get {
            // Previously this returned the arithmetic mean of the elements, which is not a
            // magnitude (it can be negative and is NaN for an empty vector).
            float sumOfSquares = 0;
            foreach (float element in this.data)
                sumOfSquares += element * element;
            return MathF.Sqrt(sumOfSquares);
        }
    }

    /// <summary>Element-wise sum.</summary>
    /// <exception cref="ArgumentException">The sizes of A and B differ.</exception>
    public static Matrix1 operator +(Matrix1 A, Matrix1 B) {
        if (A.size != B.size)
            throw new System.ArgumentException("Size of A must match size of B.");

        float[] result = new float[A.size];
        for (int i = 0; i < A.size; i++)
            result[i] = A.data[i] + B.data[i];
        return new Matrix1(result);
    }

    /// <summary>Returns this vector as a 1 x size row matrix.</summary>
    public Matrix2 Transpose() {
        // The result has a single row, so row index 0 is the only valid one.
        float[,] row = new float[1, this.size];
        for (uint colIx = 0; colIx < this.size; colIx++)
            row[0, colIx] = this.data[colIx];

        return new Matrix2(row);
    }

    /// <summary>Returns the dot product of <paramref name="a"/> and <paramref name="b"/>.</summary>
    /// <exception cref="ArgumentException">The vectors differ in length.</exception>
    public static float Dot(Matrix1 a, Matrix1 b) {
        if (a.size != b.size)
            throw new System.ArgumentException("Vectors must be of the same length.");

        float result = 0.0f;
        for (int i = 0; i < a.size; i++)
            result += a.data[i] * b.data[i];
        return result;
    }

    /// <summary>Multiplies every element by the scalar <paramref name="f"/>.</summary>
    public static Matrix1 operator *(Matrix1 A, float f) {
        float[] result = new float[A.size];
        for (int i = 0; i < A.size; i++)
            result[i] = A.data[i] * f;
        return new Matrix1(result);
    }
    /// <summary>Multiplies every element by the scalar <paramref name="f"/>.</summary>
    public static Matrix1 operator *(float f, Matrix1 A) {
        return A * f;
    }

    /// <summary>Returns elements [range.start, range.stop).</summary>
    public Matrix1 Slice(Slice range) {
        return Slice(range.start, range.stop);
    }
    /// <summary>Returns a copy of elements [<paramref name="from"/>, <paramref name="to"/>).</summary>
    /// <exception cref="ArgumentException">from &gt; to, or to &gt; size.</exception>
    public Matrix1 Slice(uint from, uint to) {
        // 'to' is exclusive, so to == size is valid and selects through the last element.
        if (from > to || to > this.size)
            throw new System.ArgumentException("Slice index out of range.");

        float[] result = new float[to - from];
        Array.Copy(this.data, from, result, 0, to - from);
        return new Matrix1(result);
    }
    /// <summary>
    /// Overwrites elements [slice.start, slice.stop) with the elements of <paramref name="v"/>,
    /// reading v from index 0.
    /// </summary>
    public void UpdateSlice(Slice slice, Matrix1 v) {
        int vIx = 0;
        for (uint ix = slice.start; ix < slice.stop; ix++, vIx++)
            this.data[ix] = v.data[vIx];
    }
}