c++ – Find the mean of two values, take two

This is my generalisation of std::midpoint, incorporating advice received on my earlier question, Find the mean of two values. As well as supporting arithmetic types and pointers, as std::midpoint does, it also supports iterators, complex numbers and user-defined types such as bignums, rationals and fixed-point numbers.

As well as advice from the first review, I’ve made a couple of other changes:

  • reject bool as argument type

  • pedantic handling of pointer difference where the result is too large for std::ptrdiff_t:

    If an array is so large (greater than PTRDIFF_MAX elements, but less
    than SIZE_MAX bytes), that the difference between two pointers may not
    be representable as std::ptrdiff_t, the result of subtracting two such
    pointers is undefined. — cppreference (CC-BY-SA 3.0)

#include <algorithm>
#include <array>
#include <cmath>
#include <complex>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <type_traits>
#include <utility>

namespace toby
{
namespace detail
{
    // A point on an affine line can be compared with another, and has
    // a difference type (which may be the same type).
    template<typename T>
    concept affine_point = std::regular<T> && requires(T a, T b) {
        a < b;
        a + (b - a);
    };

    // Subtraction helper, for pedantic handling of very large arrays
    template<typename T>
    auto distance(const T& a, const T& b) {
        return b - a;
    }
    template<typename T>
    requires ( sizeof (T) < SIZE_MAX / PTRDIFF_MAX )
    std::size_t distance(const T* a, const T* b) {
        if (b < a) {
            std::swap(a, b);
        }
        // If an array has more than PTRDIFF_MAX elements,
        // subtraction is undefined if the result is not
        // representable as std::ptrdiff_t.
        std::size_t gap = 1;
        while (a + gap < b - gap) {
            gap *= 2;
        }
        // (b - gap - a) promotes to size_t if necessary
        return b - gap - a + gap;
    }

    void midpoint(bool, bool) = delete;

    template<affine_point T>
    constexpr T midpoint(const T& a, const T& b)
    {
        if (a == b) {
            // This ensures infinities are correctly returned.
            return a;
        }

        if constexpr (std::is_signed_v<T>) {
            if ((a < 0) != (b < 0)) {
                // Values are opposite sign; avoid overflow when
                // magnitudes are large.
                return (a + b) / 2;
            }
        }

        if (a < b) {
            return a + distance(a, b) / 2;
        } else {
            return b + distance(b, a) / 2;
        }
    }

    // Iterators
    // If not random-access, then a MUST be before b
    template<std::input_or_output_iterator Iter, std::sentinel_for<Iter> S>
    constexpr Iter midpoint(Iter a, const S& b)
    {
        std::ranges::advance(a, std::ranges::distance(a, b) / 2);
        return a;
    }

    // Aggregate types follow
    // Pattern can be extended, e.g. for popular geometry types

    template<affine_point T>
    constexpr std::complex<T> midpoint(const std::complex<T>& a, const std::complex<T>& b)
    {
        return {
            midpoint(a.real(), b.real()),
            midpoint(a.imag(), b.imag())
        };
    }

    template<affine_point T, std::size_t N>
    constexpr std::array<T,N> midpoint(const std::array<T,N>& a, const std::array<T,N>& b)
    {
        std::array<T,N> result;
        auto f = ()(auto&& x, auto&& y) { return midpoint(x,y); };
        std::transform(a.begin(), a.end(), b.begin(), result.begin(), f);
        return result;
    }
}

using detail::midpoint;
}

// Tests

#include <gtest/gtest.h>

using toby::midpoint;

#include <climits>
TEST(midpoint, int)
{
    // Each sample gives two arguments and the expected midpoint.
    struct sample { int a; int b; int expected; };
    static constexpr sample samples[] = {
        {0, 0, 0},
        {0, 1, 0},
        {0, 2, 1},
        {1, 3, 2},
        {4, 1, 2},
        // Extremes, where naive formulae would overflow.
        {INT_MIN, 0, INT_MIN/2},
        {INT_MAX, 0, INT_MAX/2},
        {INT_MAX, -INT_MAX, 0},
    };
    for (auto const& s: samples) {
        EXPECT_EQ(midpoint(s.a, s.b), s.expected);
    }
}

#include <limits>
TEST(midpoint, double)
{
    constexpr double inf = std::numeric_limits<double>::infinity();
    constexpr double nan = std::numeric_limits<double>::quiet_NaN();

    // Ordinary finite values.
    EXPECT_EQ(midpoint(0.0, 0.0), 0.0);
    EXPECT_EQ(midpoint(1.0, 2.0), 1.5);

    // An infinity absorbs a finite value or an equal infinity.
    EXPECT_EQ(midpoint(1.0, inf), inf);
    EXPECT_EQ(midpoint(1.0, -inf), -inf);
    EXPECT_EQ(midpoint(inf, inf), inf);
    EXPECT_EQ(midpoint(-inf, -inf), -inf);

    // Opposite infinities, or any NaN argument, yield NaN.
    for (auto const& [x, y]: {std::pair{inf, -inf}, std::pair{nan, 0.0},
                              std::pair{0.0, nan}, std::pair{nan, nan}}) {
        EXPECT_TRUE(std::isnan(midpoint(x, y)));
    }
}

#include <complex>
TEST(midpoint, complex)
{
    // std::complex is specified only for floating-point element types;
    // CTAD from integer literals would instantiate the unspecified
    // std::complex<int>, so use double components.
    auto const a = std::complex{2.0, 10.0};
    auto const b = std::complex{0.0, 20.0};
    auto const c = std::complex{1.0, 15.0};
    EXPECT_EQ(midpoint(a, b), c);
}

TEST(midpoint, pointer)
{
    // Pointers into a const array, in both orders.
    char const s[50] = {};
    EXPECT_EQ(midpoint(s+1, s+25), s+13);
    EXPECT_EQ(midpoint(s+25, s+1), s+13);

    // Pointers into a non-const array, in both orders.
    char t[50] = {};
    EXPECT_EQ(midpoint(t+1, t+25), t+13);
    EXPECT_EQ(midpoint(t+25, t+1), t+13);
}

#include <string_view>
TEST(midpoint, iterator)
{
    // Random-access iterators, given in either order.
    constexpr std::string_view alphabet = "abcdefghijklmnopqrstuvwxyz";
    auto const first = alphabet.begin();
    auto const last = alphabet.end();
    EXPECT_EQ(*midpoint(first, last), 'n');
    EXPECT_EQ(*midpoint(last, first), 'n');
}

#include <list>
TEST(midpoint, bidi_iterator)
{
    // Bidirectional iterators: the first argument must precede the second.
    std::string_view const letters = "abcdefghijklmnopqrstuvwxyz";
    std::list<char> const chars(letters.begin(), letters.end());
    auto const middle = midpoint(chars.begin(), chars.end());
    EXPECT_EQ(*middle, 'n');
}

#include <forward_list>
TEST(midpoint, forward_iterator)
{
    // Forward-only iterators: the first argument must precede the second.
    std::string_view const letters = "abcdefghijklmnopqrstuvwxyz";
    std::forward_list<char> const chars(letters.begin(), letters.end());
    auto const middle = midpoint(chars.begin(), chars.end());
    EXPECT_EQ(*middle, 'n');
}

#include <array>
TEST(midpoint, std_array)
{
    // Element-wise midpoint of two arrays.
    std::array<int, 3> const x{ 0, 10, 20};
    std::array<int, 3> const y{10, 10, 10};
    std::array<int, 3> const expected{ 5, 10, 15};
    EXPECT_EQ(midpoint(x, y), expected);
}