#include <SFML/Copyright.hpp> // LICENSE AND COPYRIGHT (C) INFORMATION

////////////////////////////////////////////////////////////
// Headers
////////////////////////////////////////////////////////////
#include "SFML/Network/IpAddress.hpp"
#include "SFML/Network/Packet.hpp"
#include "SFML/Network/SocketImpl.hpp"
#include "SFML/Network/TcpSocket.hpp"

#include "SFML/System/Err.hpp"

#include "SFML/Base/Algorithm.hpp"
#include "SFML/Base/Optional.hpp"

#include <vector>

#include <cstring>

#ifdef _MSC_VER
#pragma warning(disable : 4127) // "conditional expression is constant" generated by the FD_SET macro
#endif


namespace
{
// Extra flags handed to every low-level send()/recv() call in this file.
// On Linux, MSG_NOSIGNAL is passed (see send(2): it suppresses SIGPIPE when the
// peer has closed the connection); every other platform passes no flags.
#ifdef SFML_SYSTEM_LINUX
constexpr int flags = MSG_NOSIGNAL;
#else
constexpr int flags = 0;
#endif
} // namespace

namespace sf
{
////////////////////////////////////////////////////////////
/// \brief Structure holding the data of a pending packet
///
////////////////////////////////////////////////////////////
struct PendingPacket
{
    std::uint32_t          size         = 0; //!< Size header exactly as read from the stream (network byte order)
    std::size_t            sizeReceived = 0; //!< How many bytes of the size header have arrived so far
    std::vector<std::byte> data;             //!< Payload bytes accumulated so far
};


////////////////////////////////////////////////////////////
struct TcpSocket::Impl
{
    // Progress of the size-prefixed packet currently being read by receive(Packet&);
    // reset to a fresh value by disconnect() and after each completed packet
    PendingPacket pendingPacket{};

    // Scratch buffer reused by send(Packet&) to assemble the size header followed by the payload
    std::vector<std::byte> blockToSendBuffer{};
};


////////////////////////////////////////////////////////////
TcpSocket::TcpSocket(bool isBlocking) : Socket(Type::Tcp, isBlocking)
{
}


////////////////////////////////////////////////////////////
// Destructor, defaulted in this file. NOTE(review): presumably defined here rather than in the
// header so that TcpSocket::Impl (defined above in this file) is a complete type at this point.
TcpSocket::~TcpSocket() = default;


////////////////////////////////////////////////////////////
// Move constructor: member-wise move (pending packet state and send buffer included), noexcept.
TcpSocket::TcpSocket(TcpSocket&&) noexcept = default;


////////////////////////////////////////////////////////////
// Move assignment: member-wise move, noexcept.
TcpSocket& TcpSocket::operator=(TcpSocket&&) noexcept = default;


////////////////////////////////////////////////////////////
unsigned short TcpSocket::getLocalPort() const
{
    // Delegates to the shared Socket helper; the "TCP socket" label is
    // presumably only used to identify this socket type in its diagnostics
    return getLocalPortImpl("TCP socket");
}


////////////////////////////////////////////////////////////
// Returns the address of the connected peer, or an empty optional (after logging
// an error) when the socket has no valid handle or the peer name cannot be queried.
// connect() relies on this to tell an accepted timed connection from a refused one.
base::Optional<IpAddress> TcpSocket::getRemoteAddress() const
{
    if (getNativeHandle() == priv::SocketImpl::invalidSocket())
    {
        priv::err() << "Attempted to get remote address of invalid TCP socket";
        return base::nullOpt;
    }

    // Retrieve information about the remote end of the socket
    priv::SockAddrIn address{};
    auto             size = address.size();

    if (!priv::SocketImpl::getPeerName(getNativeHandle(), address, size))
    {
        // The handle was validated above, so the socket is not "invalid" here; the
        // peer lookup itself failed (e.g. the socket is not connected). Keep the wording
        // consistent with getRemotePort().
        priv::err() << "Failed to retrieve remote address of TCP socket";
        return base::nullOpt;
    }

    return base::makeOptional<IpAddress>(priv::SocketImpl::ntohl(address));
}


////////////////////////////////////////////////////////////
unsigned short TcpSocket::getRemotePort() const
{
    if (getNativeHandle() == priv::SocketImpl::invalidSocket())
    {
        priv::err() << "Attempted to get remote port of invalid TCP socket";
        return 0;
    }

    // Retrieve information about the remote end of the socket
    priv::SockAddrIn address{};
    auto             size = address.size();

    if (!priv::SocketImpl::getPeerName(getNativeHandle(), address, size))
    {
        priv::err() << "Failed to retrieve remote port of TCP socket";
        return 0;
    }

    return priv::SocketImpl::ntohs(address.sinPort());
}


////////////////////////////////////////////////////////////
// Connects the socket to remoteAddress:remotePort, first closing any existing connection.
// - timeout <= Time::Zero: a single connect() call; its error status is returned on failure.
// - timeout > Time::Zero on a non-blocking socket: a single connect() call whose status is
//   returned immediately (presumably NotReady while the connection is still in progress).
// - timeout > Time::Zero on a blocking socket: the socket is temporarily switched to
//   non-blocking, select() waits for at most `timeout`, then blocking mode is restored.
// NOTE(review): getErrorStatus() presumably reads the last OS socket error, so it must stay
// immediately after the call whose failure it reports; preserve statement order when editing.
Socket::Status TcpSocket::connect(IpAddress remoteAddress, unsigned short remotePort, Time timeout)
{
    // Disconnect the socket if it is already connected
    if (getNativeHandle() != priv::SocketImpl::invalidSocket())
        (void)disconnect(); // Intentionally discard

    // Create the internal socket if it doesn't exist
    if (!create())
        return Status::Error;

    // Create the remote address
    priv::SockAddrIn address = priv::SocketImpl::createAddress(remoteAddress.toInteger(), remotePort);

    if (timeout <= Time::Zero)
    {
        // ----- We're not using a timeout: just try to connect -----

        // Connect the socket
        if (!priv::SocketImpl::connect(getNativeHandle(), address))
            return priv::SocketImpl::getErrorStatus();

        // Connection succeeded
        return Status::Done;
    }

    // ----- We're using a timeout: we'll need a few tricks to make it work -----

    // Save the previous blocking state
    const bool savedBlockingState = isBlocking();

    // Switch to non-blocking so that connect() returns instead of waiting indefinitely
    if (savedBlockingState)
        setBlocking(false);

    // Try to connect to the remote address
    if (priv::SocketImpl::connect(getNativeHandle(), address))
    {
        // We got connected instantly (this may not happen often)
        setBlocking(savedBlockingState);
        return Status::Done;
    }

    // Get the error status (must directly follow the failed connect)
    Status status = priv::SocketImpl::getErrorStatus();

    // If we were in non-blocking mode, return immediately (blocking mode was never changed)
    if (!savedBlockingState)
        return status;

    // Otherwise, wait until something happens to our socket (success, timeout or error)
    if (status == Socket::Status::NotReady)
    {
        // Wait, for at most `timeout`, until the socket is writable, which means the
        // connection request has completed (successfully or not)
        if (priv::SocketImpl::select(getNativeHandle(), timeout.asMicroseconds()) > 0)
        {
            // At this point the connection may have been either accepted or refused.
            // To know whether it's a success or a failure, we must check the address of the connected peer
            // NOTE(review): on refusal, getRemoteAddress() logs through priv::err() before
            // getErrorStatus() runs; confirm that logging cannot overwrite the OS error code.

            status = getRemoteAddress().hasValue() ? Status::Done                        // Connection accepted
                                                   : priv::SocketImpl::getErrorStatus(); // Connection refused
        }
        else
        {
            // Failed to connect before timeout is over
            status = priv::SocketImpl::getErrorStatus();
        }
    }

    // Switch back to blocking mode (savedBlockingState is necessarily true here:
    // the non-blocking case returned above)
    setBlocking(true);

    return status;
}


////////////////////////////////////////////////////////////
// Closes the socket and forgets any half-received packet, so that stale bytes
// cannot leak into a later connection. Returns whatever close() reported.
bool TcpSocket::disconnect()
{
    const bool closedOk = close();
    m_impl->pendingPacket = {};
    return closedOk;
}


////////////////////////////////////////////////////////////
// Sends `size` bytes from `data`. This overload cannot report how many bytes
// actually went out, so a partial send on a non-blocking socket cannot be resumed
// by the caller; a warning is logged in that mode.
Socket::Status TcpSocket::send(const void* data, std::size_t size)
{
    if (!isBlocking())
        priv::err() << "Warning: Partial sends might not be handled properly.";

    std::size_t bytesSent = 0; // filled by the full overload, then dropped
    return send(data, size, bytesSent);
}


////////////////////////////////////////////////////////////
// Sends `size` bytes from `data`, looping until every byte has been accepted by the OS.
// On return, `sent` holds the number of bytes transmitted. Returns Error for a null or
// empty buffer (without touching `sent`), and Partial when a non-blocking socket would
// block after some bytes already went out. Any other failure status is returned as is.
Socket::Status TcpSocket::send(const void* data, std::size_t size, std::size_t& sent)
{
    if (data == nullptr || size == 0)
    {
        priv::err() << "Cannot send data over the network (no data to send)";
        return Status::Error;
    }

    const auto* const bytes = static_cast<const char*>(data);

    sent = 0;
    while (sent < size)
    {
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wuseless-cast"
        // Hand the not-yet-sent remainder to the OS
        const int chunk = static_cast<int>(priv::SocketImpl::send(getNativeHandle(),
                                                                  bytes + sent,
                                                                  static_cast<priv::SocketImpl::Size>(size - sent),
                                                                  flags));
#pragma GCC diagnostic pop

        if (chunk < 0)
        {
            // "Would block" after some progress is reported as a partial send
            const Status status = priv::SocketImpl::getErrorStatus();
            return (status == Status::NotReady && sent > 0) ? Status::Partial : status;
        }

        sent += static_cast<std::size_t>(chunk);
    }

    return Status::Done;
}


////////////////////////////////////////////////////////////
// Performs a single recv() of at most `size` bytes into `data`.
// `received` is reset to 0 and only set on success. Returns Done when bytes arrived,
// Disconnected when recv() reports 0 bytes, Error for a null buffer, and otherwise
// the status derived from the OS error.
Socket::Status TcpSocket::receive(void* data, std::size_t size, std::size_t& received)
{
    received = 0;

    if (data == nullptr)
    {
        priv::err() << "Cannot receive data from the network (the destination buffer is invalid)";
        return Status::Error;
    }

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wuseless-cast"
    const int count = static_cast<int>(
        priv::SocketImpl::recv(getNativeHandle(), static_cast<char*>(data), static_cast<priv::SocketImpl::Size>(size), flags));
#pragma GCC diagnostic pop

    // Negative: OS-level failure (must be queried right after recv)
    if (count < 0)
        return priv::SocketImpl::getErrorStatus();

    // Zero bytes from a stream socket: the peer closed the connection
    if (count == 0)
        return Socket::Status::Disconnected;

    received = static_cast<std::size_t>(count);
    return Status::Done;
}


////////////////////////////////////////////////////////////
// Sends `packet` as a 32-bit big-endian (network order) size prefix followed by its payload.
// Supports resuming: on Status::Partial the packet's send position is advanced, so the next
// call with the same packet continues where this one stopped. On Status::Done it is reset to 0.
// Returns Status::Error, without sending anything, if the payload cannot be described by the
// 32-bit size prefix.
Socket::Status TcpSocket::send(Packet& packet)
{
    // TCP is a stream protocol, it doesn't preserve message boundaries.
    // This means that we have to send the packet size first, so that the
    // receiver knows the actual end of the packet in the data stream.

    // We allocate an extra memory block so that the size can be sent
    // together with the data in a single call. This may seem inefficient,
    // but it is actually required to avoid partial send, which could cause
    // data corruption on the receiving end.

    // Get the data to send from the packet
    std::size_t size = 0;
    const void* data = packet.onSend(size);

    // The size prefix is 32 bits wide: a larger payload would be announced with a
    // truncated size while its full data is still sent, desynchronizing the receiver
    constexpr std::size_t maxPacketSize = static_cast<std::uint32_t>(-1);
    if (size > maxPacketSize)
    {
        priv::err() << "Cannot send packet over the network (packet size exceeds the 32-bit size limit)";
        return Status::Error;
    }

    // First convert the packet size to network byte order
    const std::uint32_t packetSize = priv::SocketImpl::htonl(static_cast<std::uint32_t>(size));

    // Allocate memory for the data block to send (the buffer is reused across calls)
    m_impl->blockToSendBuffer.resize(sizeof(packetSize) + size);

// Copy the packet size and data into the block to send
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wnull-dereference" // False positive.
    std::memcpy(m_impl->blockToSendBuffer.data(), &packetSize, sizeof(packetSize));
#pragma GCC diagnostic pop
    if (size > 0)
        std::memcpy(m_impl->blockToSendBuffer.data() + sizeof(packetSize), data, size);

// These warnings are ignored here for portability, as even on Windows the
// signature of `send` might change depending on whether Win32 or MinGW is
// being used.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wuseless-cast"
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wsign-conversion"
    // Send the data block, starting from where a previous partial send stopped
    std::size_t  sent   = 0;
    const Status status = send(m_impl->blockToSendBuffer.data() + packet.getSendPos(),
                               static_cast<priv::SocketImpl::Size>(m_impl->blockToSendBuffer.size() - packet.getSendPos()),
                               sent);
#pragma GCC diagnostic pop
#pragma GCC diagnostic pop

    // In the case of a partial send, record the location to resume from
    if (status == Status::Partial)
    {
        packet.getSendPos() += sent;
    }
    else if (status == Status::Done)
    {
        packet.getSendPos() = 0;
    }

    return status;
}


////////////////////////////////////////////////////////////
// Receives one size-prefixed packet (32-bit network-order size, then payload) into `packet`.
// `packet` is cleared first. If the underlying receive does not return Done (e.g. NotReady on
// a non-blocking socket, or Disconnected), that status is returned and the progress made so
// far is kept in m_impl->pendingPacket, so a later call resumes where this one stopped.
// Returns Done once a whole packet has been handed to packet.onReceive() (not called for
// an empty payload).
Socket::Status TcpSocket::receive(Packet& packet)
{
    // First clear the variables to fill
    packet.clear();

    PendingPacket& pending  = m_impl->pendingPacket;
    std::size_t    received = 0;

    // Receive the size header; even a 4-byte value may arrive over several calls.
    // If it was completed by a previous call, this loop does nothing.
    while (pending.sizeReceived < sizeof(pending.size))
    {
        char* const  dest   = reinterpret_cast<char*>(&pending.size) + pending.sizeReceived;
        const Status status = receive(dest, sizeof(pending.size) - pending.sizeReceived, received);
        pending.sizeReceived += received;

        if (status != Status::Done)
            return status;
    }

    // The size header is complete
    const std::uint32_t packetSize = priv::SocketImpl::ntohl(pending.size);

    // Loop until we receive all the packet data
    std::byte buffer[1024]{};
    while (pending.data.size() < packetSize)
    {
        // Receive a chunk of data, never more than what is still missing
        const std::size_t sizeToGet = base::min(packetSize - pending.data.size(), sizeof(buffer));
        const Status      status    = receive(buffer, sizeToGet, received);
        if (status != Status::Done)
            return status;

        // Append it to the pending payload
        pending.data.insert(pending.data.end(), buffer, buffer + received);
    }

    // We have received all the packet data: we can copy it to the user packet
    if (!pending.data.empty())
        packet.onReceive(pending.data.data(), pending.data.size());

    // Clear the pending packet data for the next packet
    pending = PendingPacket{};

    return Status::Done;
}

} // namespace sf
