// SPDX-License-Identifier: Zlib
/* ------------------------------------------------------------------------- */
/*!
 *  \file       mgl_stl_memory.h
 *  \brief      MGL STLのメモリ関連の代替
 *  \date       Since: April 29, 2022. 23:19:12 JST.
 *  \author     Acerola
 */
/* ------------------------------------------------------------------------- */

#ifndef INCGUARD_MGL_STL_MEMORY_H_1651241952
#define INCGUARD_MGL_STL_MEMORY_H_1651241952

#include <memory>
#include <type_traits>

#include <mgl/memory/mgl_memory.h>

namespace MGL::STL
{
//! STL用アロケータ
template <typename T>
class Allocator
{
public:
    using value_type = T;

    Allocator() noexcept = default;
    Allocator(const Allocator &) noexcept = default;

    template <class U>
    Allocator(const Allocator<U> & /*unused*/) noexcept {};

    ~Allocator() = default;

    T *allocate(std::size_t n)
    {
        auto *ptr = reinterpret_cast<T *>(MGL::Memory::Allocate(n * sizeof(T)));
        if (ptr == nullptr)
        {
            throw std::bad_alloc();
        }

        return ptr;
    }

    void deallocate(T *p, [[maybe_unused]] std::size_t n)
    {
        MGL::Memory::Deallocate(p);
    }
};

template <class T, class U>
bool operator==(const Allocator<T> & /*unused*/, const Allocator<U> & /*unused*/) noexcept
{
    return true;
}

template <class T, class U>
bool operator!=(const Allocator<T> & /*unused*/, const Allocator<U> & /*unused*/) noexcept
{
    return false;
}

//! STL用デリータ
struct Deleter
{
    using DestroyFunc = void (*)(void *obj, size_t size);

    struct Header
    {
        DestroyFunc destroyer;
        size_t size;
    };
    static_assert(sizeof(Header) <= 16);

    void operator()(void *ptr) const
    {
        if (ptr != nullptr)
        {
            auto *top = reinterpret_cast<Header *>(ptr) - 1;

            if (top->destroyer != nullptr)
            {
                top->destroyer(top + 1, top->size);
            }

            MGL::Memory::Deallocate(top);
        }
    }
};

//! デストラクタ呼び出し用テンプレート関数
template <class T>
constexpr void NonTrivialDestroyer(void *obj, size_t size)
{
    std::destroy_n(reinterpret_cast<T *>(obj), size);
}


//! 非配列型の判別用テンプレート
template <class T>
struct UniquePtrType
{
    using UniquePtr_Single = std::unique_ptr<T, Deleter>;
};

//! 配列型の判別用テンプレート
template <class T>
struct UniquePtrType<T[]>
{
    using UniquePtr_Array = std::unique_ptr<T[], Deleter>;
};

//! サイズ指定の配列型の判別用テンプレート
template <class T, size_t n>
struct UniquePtrType<T[n]>
{
    using UniquePtr_Invalid = void;
};


//! 非配列型のユニークポインタの生成
template <class T, class... Args>
[[nodiscard]] inline typename UniquePtrType<T>::UniquePtr_Single make_unique(Args &&...args)
{
    // 必要な容量のバッファをアロケートして先頭アドレスをヘッダに設定する
    auto *header = reinterpret_cast<Deleter::Header *>(MGL::Memory::Allocate(sizeof(T) + sizeof(Deleter::Header)));
    if (header == nullptr)
    {
        throw std::bad_alloc();
    }

    // 必要ならばデストラクタの呼び出し関数を設定
    if constexpr (std::is_trivially_destructible_v<T>)
    {
        header->destroyer = nullptr;
    }
    else
    {
        header->destroyer = NonTrivialDestroyer<T>;
    }

    // サイズは常に1
    header->size = 1;

    // 生成するオブジェクトはヘッダの次のアドレス
    auto *obj = reinterpret_cast<T *>(header + 1);

    // 例外なし
    if constexpr (std::is_nothrow_constructible_v<T, Args...>)
    {
        obj = new (obj) T(std::forward<Args>(args)...);
    }
    // 例外あり
    else
    {
        try
        {
            obj = new (obj) T(std::forward<Args>(args)...);
        }
        catch (...)
        {
            MGL::Memory::Deallocate(header);
            throw;
        }
    }

    return std::unique_ptr<T, Deleter>(obj);
}

//! 配列型のユニークポインタの生成
template <class T>
[[nodiscard]] inline typename UniquePtrType<T>::UniquePtr_Array make_unique(size_t n)
{
    using U = std::remove_extent_t<T>;

    // 必要な容量のバッファをアロケートして先頭アドレスをヘッダに設定する
    auto *header = reinterpret_cast<Deleter::Header *>(MGL::Memory::Allocate((sizeof(U) * n) + sizeof(Deleter::Header)));
    if (header == nullptr)
    {
        return std::unique_ptr<T, Deleter>(nullptr);
    }

    // 必要ならばデストラクタの呼び出し関数を設定
    if constexpr (std::is_trivially_destructible_v<U>)
    {
        header->destroyer = nullptr;
    }
    else
    {
        header->destroyer = NonTrivialDestroyer<U>;
    }

    // ヘッダにサイズを設定
    header->size = n;

    // 生成するオブジェクトはヘッダの次のアドレス
    auto *obj = reinterpret_cast<U *>(header + 1);
    U *top = nullptr;

    // 例外なし
    if constexpr (std::is_nothrow_constructible_v<U>)
    {
        for (size_t i = 0; i < n; i++)
        {
            auto *ptr = obj + i;
            ptr = new (ptr) U();

            if (i == 0)
            {
                top = ptr;
            }
        }
    }
    // 例外あり
    else
    {
        for (size_t i = 0; i < n; i++)
        {
            U *ptr = nullptr;
            try
            {
                ptr = new (obj + i) U();
            }
            catch (...)
            {
                if (top != nullptr)
                {
                    for (auto j = static_cast<int64_t>(i) - 1; j >= 0; j--)
                    {
                        std::destroy_at(top + j);
                    }
                }

                MGL::Memory::Deallocate(header);
                throw;
            }

            if (i == 0)
            {
                top = ptr;
            }
        }
    }

    return std::unique_ptr<T, Deleter>(top);
}

// サイズ指定の配列型のユニークポインタは扱えない
template <class T, class... Args>
constexpr typename UniquePtrType<T>::UniquePtr_Invalid make_unique(Args &&...) = delete;

//! MGLのアロケータを利用するユニークポインタ
template <class T>
using unique_ptr = std::unique_ptr<T, Deleter>;

//! シェアードポインタの生成
template <class T, class... Args>
[[nodiscard]] constexpr auto make_shared(Args &&...args)
{
    const Allocator<void> alloc;
    return std::allocate_shared<T>(alloc, std::forward<Args>(args)...);
}

}    // namespace MGL::STL

#endif    // INCGUARD_MGL_STL_MEMORY_H_1651241952

// vim: et ts=4 sw=4 sts=4
