c++ – Compile-time Matrix Class

As a small project to try out various C++20 features, and to learn a little more about matrices and their uses, I decided to implement a relatively simple matrix class.

After implementing it, I realized that every part of it could be evaluated in a compile-time context, so I re-implemented it (and cleaned it up a little). I’ll provide the whole code here and then highlight some parts where I’m unsure whether there’s a better way. There’s a very barebones vector class alongside it, used by the matrix multiplication operator.

#include <array>
#include <concepts>
#include <functional>
#include <stdexcept>
#include <type_traits>
#include <utility>

namespace ctm
{
    // Unsigned type used for every dimension and row/column index.
    using index_t = std::size_t;
    // Scalar type stored in vectors and matrices.
    using real = long double;

    /// Minimal fixed-size vector of `real`, used for matrix rows and columns.
    /// @tparam N number of components; must be positive.
    template <index_t N>
    requires (N > 0)
    struct vector
    {
        std::array<real, N> values;

        /// Copies all N components from @p vals.
        constexpr vector(const std::array<real, N>& vals)
            : values(vals)
        {
        }

        /// Dot product of this vector with @p other.
        /// Terms are accumulated from index 0 upward, starting at 0.
        [[nodiscard]] constexpr real dot(const vector<N>& other) const
        {
            real sum = 0.L;

            for (index_t i = 0; i < N; ++i)
            {
                sum += values[i] * other.values[i];
            }

            return sum;
        }

        /// 0-based, unchecked component access.
        constexpr real& operator [](index_t index)
        {
            return values[index];
        }

        /// 0-based, unchecked component access (read-only).
        constexpr const real& operator [](index_t index) const
        {
            return values[index];
        }
    };

    /// Fixed-size M x N matrix of `real` whose operations are all usable in
    /// constant expressions.
    ///
    /// Indexing conventions:
    ///  - operator()(i) / operator[](i): row i, 0-based, unchecked.
    ///  - operator()(i, j), row, column, first_minor: 1-based (mathematical
    ///    convention); they throw std::out_of_range on invalid indices.
    ///
    /// @tparam M number of rows; must be positive.
    /// @tparam N number of columns; must be positive.
    template <index_t M, index_t N>
    requires (M > 0 && N > 0)
    struct matrix
    {
        /// Row-major storage: elements[row][column], both 0-based.
        std::array<std::array<real, N>, M> elements = {};

        constexpr matrix() = default;

        /// Builds the matrix from M * N values given in row-major order.
        constexpr matrix(const std::array<real, M * N>& elems)
        {
            for (index_t i = 0; i < M; ++i)
            {
                for (index_t j = 0; j < N; ++j)
                {
                    elements[i][j] = elems[i * N + j];
                }
            }
        }

        /// Builds the matrix by evaluating func(i, j) for each element, with
        /// 1-based indices i in [1, M] and j in [1, N].
        ///
        /// Accepts any callable (lambda, function pointer, std::function, ...)
        /// whose result converts to `real`. Taking the callable by template
        /// parameter instead of std::function keeps the constructor usable in
        /// constant expressions: std::function cannot be constructed during
        /// constant evaluation.
        /// `explicit` stops callables from silently becoming matrices. This
        /// includes other matrices, which are callable through
        /// operator()(index_t, index_t).
        template <typename F>
        requires (!std::same_as<std::remove_cvref_t<F>, matrix>)
              && requires (const F& f, index_t i, index_t j)
                 {
                     { f(i, j) } -> std::convertible_to<real>;
                 }
        constexpr explicit matrix(const F& func)
        {
            for (index_t i = 1; i <= M; ++i)
            {
                for (index_t j = 1; j <= N; ++j)
                {
                    elements[i - 1][j - 1] = static_cast<real>(func(i, j));
                }
            }
        }

        // accessor functions

        /// Row @p index, 0-based, unchecked.
        constexpr std::array<real, N>& operator ()(index_t index)
        {
            return elements[index];
        }

        /// Row @p index, 0-based, unchecked (read-only).
        constexpr const std::array<real, N>& operator ()(index_t index) const
        {
            return elements[index];
        }

        /// Row @p index, 0-based, unchecked. Same as operator()(index_t).
        constexpr std::array<real, N>& operator [](index_t index)
        {
            return elements[index];
        }

        /// Row @p index, 0-based, unchecked (read-only). Same as operator()(index_t).
        constexpr const std::array<real, N>& operator [](index_t index) const
        {
            return elements[index];
        }

        /// Element (index_m, index_n), 1-based.
        /// @throws std::out_of_range if an index is 0 or exceeds the dimension.
        constexpr real& operator ()(index_t index_m, index_t index_n)
        {
            check_compound_indices(index_m, index_n);
            return elements[index_m - 1][index_n - 1];
        }

        /// Element (index_m, index_n), 1-based (read-only).
        /// @throws std::out_of_range if an index is 0 or exceeds the dimension.
        constexpr const real& operator ()(index_t index_m, index_t index_n) const
        {
            check_compound_indices(index_m, index_n);
            return elements[index_m - 1][index_n - 1];
        }

        /// Copy of row @p index, 1-based.
        /// @throws std::out_of_range if index is 0 or greater than M.
        constexpr vector<N> row(index_t index) const
        {
            if (index == 0)
            {
                throw std::out_of_range{ "row access is 1-indexed " };
            }
            else if (index > M)
            {
                throw std::out_of_range("row access out of range");
            }

            return vector<N>{ elements[index - 1] };
        }

        /// Copy of column @p index, 1-based.
        /// @throws std::out_of_range if index is 0 or greater than N.
        constexpr vector<M> column(index_t index) const
        {
            if (index == 0)
            {
                throw std::out_of_range{ "column access is 1-indexed " };
            }
            else if (index > N)
            {
                throw std::out_of_range("column access out of range");
            }

            std::array<real, M> values = {};

            for (index_t r = 0; r < M; ++r)
            {
                values[r] = elements[r][index - 1];
            }

            return vector<M>{ values };
        }

        /// Returns the N x M transpose.
        constexpr matrix<N, M> transpose() const
        {
            matrix<N, M> result = {};

            for (index_t j = 0; j < N; ++j)
            {
                for (index_t i = 0; i < M; ++i)
                {
                    result.elements[j][i] = elements[i][j];
                }
            }

            return result;
        }

        /// Dimensions as {rows, columns}.
        /// Declared static constexpr rather than consteval. A consteval member
        /// call on a non-constexpr object is rejected by C++20 compilers.
        /// m.size() still works, and so does use in constant expressions.
        static constexpr std::pair<index_t, index_t> size() noexcept
        {
            return { M, N };
        }

        // arithmetic operators

        /// Element-wise addition in place.
        constexpr matrix& operator +=(const matrix& other)
        {
            for (index_t i = 0; i < M; ++i)
            {
                for (index_t j = 0; j < N; ++j)
                {
                    elements[i][j] += other.elements[i][j];
                }
            }

            return *this;
        }

        /// Element-wise subtraction in place.
        constexpr matrix& operator -=(const matrix& other)
        {
            for (index_t i = 0; i < M; ++i)
            {
                for (index_t j = 0; j < N; ++j)
                {
                    elements[i][j] -= other.elements[i][j];
                }
            }

            return *this;
        }

        /// Multiplies every element by @p scalar in place.
        constexpr matrix& operator *=(real scalar)
        {
            for (index_t i = 0; i < M; ++i)
            {
                for (index_t j = 0; j < N; ++j)
                {
                    elements[i][j] *= scalar;
                }
            }

            return *this;
        }

        /// Divides every element by @p scalar in place (no zero check).
        constexpr matrix& operator /=(real scalar)
        {
            for (index_t i = 0; i < M; ++i)
            {
                for (index_t j = 0; j < N; ++j)
                {
                    elements[i][j] /= scalar;
                }
            }

            return *this;
        }

        constexpr friend matrix operator +(matrix left, const matrix& right)
        {
            return left += right;
        }

        constexpr friend matrix operator -(matrix left, const matrix& right)
        {
            return left -= right;
        }

        constexpr friend matrix operator *(matrix mat, real scalar)
        {
            return mat *= scalar;
        }

        constexpr friend matrix operator *(real scalar, matrix mat)
        {
            return mat *= scalar;
        }

        constexpr friend matrix operator /(matrix mat, real scalar)
        {
            return mat /= scalar;
        }

        /// Matrix product (M x N) * (N x P) -> (M x P).
        /// Each entry is summed over k = 0..N-1 in order, starting from 0. This
        /// reads the storage directly instead of copying a row and a column and
        /// bounds-checking for every entry.
        template <index_t P>
        constexpr friend matrix<M, P> operator *(const matrix<M, N>& left, const matrix<N, P>& right)
        {
            matrix<M, P> result;

            for (index_t i = 0; i < M; ++i)
            {
                for (index_t j = 0; j < P; ++j)
                {
                    real sum = 0.L;

                    for (index_t k = 0; k < N; ++k)
                    {
                        sum += left.elements[i][k] * right.elements[k][j];
                    }

                    result.elements[i][j] = sum;
                }
            }

            return result;
        }

        // comparison operator

        constexpr friend bool operator ==(const matrix&, const matrix&) = default;

        // the following functions are only valid for square matrices (M == N)

        /// First minor: the matrix with row @p i and column @p j removed (1-based).
        ///
        /// This is a template because the return type mentions matrix<M-1, N-1>.
        /// In a non-template member, that type would be instantiated together
        /// with the class, and for M == 1 it would name the invalid matrix<0, 0>.
        /// The m == M / n == N constraints reject explicit template arguments
        /// that differ from the real dimensions; those would index past the
        /// end of `elements`.
        /// @throws std::out_of_range if i or j is 0 or exceeds the dimension.
        template <index_t m = M, index_t n = N>
        requires (m == M && n == N && m == n && m > 1)
        constexpr matrix<m - 1, n - 1> first_minor(index_t i, index_t j) const
        {
            if (i == 0 || j == 0)
            {
                throw std::out_of_range("minor row and column indices are 1-indexed");
            }
            else if (i > M || j > N)
            {
                throw std::out_of_range("minor row and column indices out of range");
            }

            matrix<m - 1, n - 1> minor = {};

            index_t k = 0;
            for (index_t p = 0; p < m; ++p)
            {
                // skip over the row specified by i
                if (p == i - 1)
                {
                    continue;
                }

                index_t l = 0;
                for (index_t q = 0; q < n; ++q)
                {
                    // skip over the column specified by j
                    if (q == j - 1)
                    {
                        continue;
                    }

                    minor.elements[k][l] = elements[p][q];
                    ++l;
                }

                ++k;
            }

            return minor;
        }

        /// Sum of the main diagonal.
        constexpr real trace() const requires(M == N)
        {
            real sum = 0.L;

            for (index_t i = 0; i < M; ++i)
            {
                sum += elements[i][i];
            }

            return sum;
        }

        /// Determinant of a 1 x 1 matrix: its single element.
        constexpr real determinant() const requires(M == N && M == 1)
        {
            return elements[0][0];
        }

        /// Determinant of a 2 x 2 matrix (closed form ad - bc).
        constexpr real determinant() const requires(M == N && M == 2)
        {
            return elements[0][0] * elements[1][1] - elements[0][1] * elements[1][0];
        }

        /// Determinant by Laplace (cofactor) expansion along the first row.
        constexpr real determinant() const requires(M == N && M > 2)
        {
            real sum = 0.L;
            real parity = 1.L;

            for (index_t j = 0; j < N; ++j)
            {
                auto submatrix = first_minor(1, j + 1);

                sum += elements[0][j] * parity * submatrix.determinant();
                parity = -parity;
            }

            return sum;
        }

        /// True if the matrix equals its transpose (exact comparison).
        constexpr bool symmetric() const requires(M == N)
        {
            return *this == transpose();
        }

        /// True if the matrix equals the negation of its transpose (exact comparison).
        constexpr bool skew_symmetric() const requires(M == N)
        {
            return *this == (transpose() *= -1.L);
        }

        // static square matrix creators

        /// The M x M identity matrix.
        static constexpr matrix<M, N> identity() requires(M == N)
        {
            matrix<M, N> id = {};

            for (index_t i = 0; i < M; ++i)
            {
                id.elements[i][i] = 1.L;
            }

            return id;
        }

        /// Square matrix with @p elems on the main diagonal and zeros elsewhere.
        static constexpr matrix<M, N> diagonal(const std::array<real, M>& elems) requires(M == N)
        {
            matrix<M, N> diag = {};

            for (index_t i = 0; i < M; ++i)
            {
                diag.elements[i][i] = elems[i];
            }

            return diag;
        }

    private:
        /// Shared validation for the 1-based operator()(index_t, index_t).
        /// @throws std::out_of_range if an index is 0 or exceeds the dimension.
        static constexpr void check_compound_indices(index_t index_m, index_t index_n)
        {
            if (index_m == 0 || index_n == 0)
            {
                throw std::out_of_range("compound access is 1-indexed");
            }
            else if (index_m > M || index_n > N)
            {
                throw std::out_of_range("compound access out of range");
            }
        }
    };
};

One area that has caused trouble in both of my implementations is the constructor that takes a std::function object. A matrix can have its elements defined by a function of each element's indices, and I wanted to replicate that. Is this the most effective way? I had to make it explicit because a matrix<M, N> is implicitly convertible to a std::function<real(index_t, index_t)>, since it defines operator ()(index_t, index_t) for 1-indexed element access.

Also, is there any way to clean up the declaration of first_minor? I dislike what I have to do to make it work, but if I try to follow the same pattern as the other functions (a trailing requires clause), the compiler complains.

General advice would also be appreciated, thanks!