xaizek / tos (License: GPLv3 only) (since 2018-12-07)
This is an alternative version of sources presented as part of Write Your Own OS video tutorial by Viktor Engelmann.
<root> / src / net / tcp.cpp (274540fb09b02b97adb963e9bc7b6ac05440e70b) (9,129B) (mode 100644) [raw]
#include "net/tcp.hpp"

#include <cstddef>
#include <memory>
#include <utility>

#include "utils/conv.hpp"

using namespace net;

// Creates a socket bound to `backend`, with no handler attached and in the
// Closed state.  Fields such as ports, addresses and sequence numbers are
// filled in later by the provider (connect/listen/onIPReceived).
TCPSocket::TCPSocket(TCPProvider &backend)
    : backend(backend)
{
    handler = nullptr;
    state = TCPSocketState::Closed;
}

// Passes a segment's payload to the attached handler.
// Returns the handler's result, or false when no handler has been bound.
bool
TCPSocket::handleTCPMessage(span<std::uint8_t> msg)
{
    return (handler == nullptr) ? false
                                : handler->handleTCPMessage(*this, msg);
}

// Sends `buf` (flagged ACK) once the connection is established.
//
// Spins, with no timeout, until the receive path moves this socket to the
// Established state; gives up silently if the socket becomes Closed first.
// NOTE(review): TCPProvider::onIPReceived deletes sockets that reach Closed,
// so this loop may end up reading freed memory, and `volatile` only forces
// re-reads -- it is not a synchronization primitive.  Both are pre-existing
// design issues that need changes outside this function.
void
TCPSocket::send(span<const std::uint8_t> buf)
{
    volatile TCPSocketState &current =
        const_cast<volatile TCPSocketState &>(this->state);

    for (;;) {
        if (current == TCPSocketState::Established) {
            break;
        }
        if (current == TCPSocketState::Closed) {
            return;
        }
    }

    backend.send(*this, buf, TCPFlags::ACK);
}

void
TCPSocket::disconnect()
{
    backend.disconnect(*this);
}

TCPProvider::TCPProvider(IPProvider &backend)
    : IPHandler(backend, 0x06), numSockets(0), freePort(1024)
{
}

// Handles one TCP segment delivered by the IP layer.
//
// Finds the socket the segment belongs to (a Listen socket for a bare SYN,
// otherwise the socket whose local and remote endpoints match), advances the
// simplified TCP state machine, hands in-order payload to the socket's
// handler and removes the socket from the table once it reaches the Closed
// state.  A segment that matches no socket is answered with RST, unless it
// carries RST itself.
//
// Always returns false.
bool
TCPProvider::onIPReceived(NetOrder<std::uint32_t> srcIP,
                          NetOrder<std::uint32_t> dstIP,
                          span<std::uint8_t> msg)
{
    // 20 bytes is the size of a TCP header without options.
    if (msg.size() < 20) {
        return false;
    }

    TCPHeader tcpHdr(msg.data());
    auto flags = static_cast<TCPFlags>(tcpHdr.flags.BE);

    // The data offset counts 32-bit words.  Drop segments whose claimed
    // header is shorter than the minimum or extends past the end of the
    // segment: slicing off the payload and advancing ackNum by the payload
    // length below would otherwise go out of range or wrap around.
    const int headerSize = (tcpHdr.offsetAndReserved.BE >> 4)*4;
    if (headerSize < 20 || msg.size() < static_cast<std::size_t>(headerSize)) {
        return false;
    }

    TCPSocket *socket = nullptr;
    for (std::uint16_t i = 0; i < numSockets && socket == nullptr; ++i) {
        if (sockets[i]->localPort == tcpHdr.dstPort &&
            sockets[i]->localIP == dstIP &&
            sockets[i]->state == TCPSocketState::Listen &&
            (flags & (TCPFlags::SYN | TCPFlags::ACK)) == TCPFlags::SYN) {
            socket = sockets[i].get();
        } else if (sockets[i]->localPort == tcpHdr.dstPort &&
                   sockets[i]->localIP == dstIP &&
                   sockets[i]->remotePort == tcpHdr.srcPort &&
                   sockets[i]->remoteIP == srcIP) {
            socket = sockets[i].get();
        }
    }

    if (socket == nullptr) {
        // A reset that matches no connection is discarded; replying with
        // another RST is forbidden (RFC 793, segment arrival in CLOSED).
        if ((flags & TCPFlags::RST) != TCPFlags::None) {
            return false;
        }

        // Temporary socket used only to address the RST reply.
        TCPSocket reply(*this);
        reply.remotePort = tcpHdr.srcPort;
        reply.remoteIP = srcIP;
        reply.localPort = tcpHdr.dstPort;
        reply.localIP = dstIP;
        reply.seqNum = ntoh(tcpHdr.ackNum.BE);
        reply.ackNum = ntoh(tcpHdr.seqNum.BE) + 1;
        send(reply, {}, TCPFlags::RST);
        return false;
    }

    bool reset = false;

    if ((flags & TCPFlags::RST) != TCPFlags::None) {
        socket->state = TCPSocketState::Closed;
    }

    if (socket->state != TCPSocketState::Closed) {
        switch (flags & (TCPFlags::SYN | TCPFlags::ACK | TCPFlags::FIN)) {
            case +TCPFlags::SYN:
                if (socket->state != TCPSocketState::Listen) {
                    reset = true;
                    break;
                }

                socket->state = TCPSocketState::SynReceived;
                socket->remotePort = tcpHdr.srcPort;
                socket->remoteIP = srcIP;
                socket->ackNum = ntoh(tcpHdr.seqNum.BE) + 1;
                socket->seqNum = 0xbeefcafe;
                send(*socket, {}, TCPFlags::SYN | TCPFlags::ACK);
                // The SYN consumes one sequence number.
                ++socket->seqNum;
                break;

            case TCPFlags::SYN | TCPFlags::ACK:
                if (socket->state == TCPSocketState::SynSent) {
                    socket->state = TCPSocketState::Established;
                    socket->ackNum = ntoh(tcpHdr.seqNum.BE) + 1;
                    ++socket->seqNum;
                    send(*socket, {}, TCPFlags::ACK);
                }
                else
                    reset = true;
                break;


            case TCPFlags::SYN | TCPFlags::FIN:
            case TCPFlags::SYN | TCPFlags::FIN | TCPFlags::ACK:
                reset = true;
                break;


            case +TCPFlags::FIN:
            case TCPFlags::FIN | TCPFlags::ACK:
                if (socket->state == TCPSocketState::Established) {
                    socket->state = TCPSocketState::CloseWait;
                    ++socket->ackNum;
                    send(*socket, {}, TCPFlags::ACK);
                    send(*socket, {}, TCPFlags::FIN | TCPFlags::ACK);
                } else if (socket->state == TCPSocketState::CloseWait) {
                    socket->state = TCPSocketState::Closed;
                } else if (socket->state == TCPSocketState::FinWait1 ||
                           socket->state == TCPSocketState::FinWait2) {
                    socket->state = TCPSocketState::Closed;
                    ++socket->ackNum;
                    send(*socket, {}, TCPFlags::ACK);
                } else {
                    reset = true;
                }
                break;


            case +TCPFlags::ACK:
                if (socket->state == TCPSocketState::SynReceived) {
                    socket->state = TCPSocketState::Established;
                    return false;
                } else if (socket->state == TCPSocketState::FinWait1) {
                    socket->state = TCPSocketState::FinWait2;
                    return false;
                } else if (socket->state == TCPSocketState::CloseWait) {
                    socket->state = TCPSocketState::Closed;
                    break;
                }

                if (flags == TCPFlags::ACK) {
                    break;
                }

                // fall through - the ACK may piggyback data

            default:
                if (ntoh(tcpHdr.seqNum.BE) == socket->ackNum) {
                    reset = !socket->handleTCPMessage(msg.after(headerSize));
                    if (!reset) {
                        // Cannot wrap: headerSize <= msg.size() was checked.
                        socket->ackNum += msg.size() - headerSize;
                        send(*socket, {}, TCPFlags::ACK);
                    }
                } else {
                    // data in wrong order
                    reset = true;
                }
                break;
        }
    }

    if (reset) {
        send(*socket, {}, TCPFlags::RST);
    }

    if (socket->state == TCPSocketState::Closed) {
        for (std::uint16_t i = 0U; i < numSockets; ++i) {
            if (sockets[i].get() == socket) {
                // The table owns its sockets.  Take ownership of this one so
                // it is destroyed exactly once (at the end of this scope),
                // then fill the hole with the last entry.  Do not also
                // `delete socket`: that would free it a second time.
                auto doomed = std::move(sockets[i]);
                --numSockets;
                if (i != numSockets) {
                    sockets[i] = std::move(sockets[numSockets]);
                }
                break;
            }
        }
    }

    return false;
}

// Builds a TCP segment for `socket` carrying `buf` with the given `flags`
// and passes it to the IP layer.  Advances socket.seqNum by the payload size.
//
// The scratch buffer is laid out as [pseudo header | TCP header | payload].
// The checksum is computed over all of it, but only the TCP header and the
// payload are handed to IPHandler::send.
//
// Payloads too large for the 16-bit length fields are dropped: nothing is
// sent and seqNum is not advanced.  Previously the lengths wrapped and the
// payload copy overran the undersized buffer.
void
TCPProvider::send(TCPSocket &socket, span<const std::uint8_t> buf,
                  TCPFlagsSubset flags)
{
    const std::size_t maxPayload =
        0xffffU - TCPHeader::size() - TCPPseudoHeader::size();
    if (buf.size() > maxPayload) {
        return;
    }

    // Both values fit in 16 bits thanks to the check above.
    std::uint16_t totalLength =
        static_cast<std::uint16_t>(buf.size() + TCPHeader::size());
    std::uint16_t lengthInclPHdr =
        static_cast<std::uint16_t>(totalLength + TCPPseudoHeader::size());

    std::unique_ptr<std::uint8_t[]> buffer(new std::uint8_t[lengthInclPHdr]);

    TCPPseudoHeader phdr(buffer.get());
    TCPHeader tcpHdr(buffer.get() + TCPPseudoHeader::size());
    std::uint8_t *payload = buffer.get()
                          + TCPHeader::size()
                          + TCPPseudoHeader::size();

    // Data offset (header length in 32-bit words) lives in the upper nibble.
    tcpHdr.offsetAndReserved =
        netOrder(static_cast<std::uint8_t>((TCPHeader::size()/4) << 4));
    tcpHdr.srcPort = socket.localPort;
    tcpHdr.dstPort = socket.remotePort;

    tcpHdr.ackNum = toNetOrder(socket.ackNum);
    tcpHdr.seqNum = toNetOrder(socket.seqNum);
    tcpHdr.flags = netOrder(flags.val);
    tcpHdr.windowSize = toNetOrder<std::uint16_t>(0xffff);
    tcpHdr.urgentPtr = toNetOrder<std::uint16_t>(0);

    // On SYN, advertise the MSS option: kind 2, length 4, value 0x05b4
    // (1460).  Otherwise the option word is zero-filled.
    tcpHdr.options = toNetOrder<std::uint32_t>(
        ((flags & TCPFlags::SYN) != TCPFlags::None) ? 0x020405b4 : 0
    );

    socket.seqNum += buf.size();

    for (decltype(buf.size()) i = 0; i < buf.size(); ++i) {
        payload[i] = buf[i];
    }

    phdr.srcIP = socket.localIP;
    phdr.dstIP = socket.remoteIP;
    phdr.protocol = toNetOrder<std::uint16_t>(0x0006);
    phdr.totalLength = toNetOrder(totalLength);

    // The checksum field must be zero while the checksum is computed.
    tcpHdr.checksum = toNetOrder<std::uint16_t>(0);
    tcpHdr.checksum = IPProvider::checksum({ buffer.get(), lengthInclPHdr });

    IPHandler::send(socket.remoteIP,
                    { buffer.get() + TCPPseudoHeader::size(), totalLength });
}

// Opens an active connection to `ip`:`port` (`port` in host byte order).
//
// Takes the next local port from freePort and uses this host's IP address.
// It then adds the socket to the table, moves it to SynSent with the fixed
// initial sequence number 0xbeefcafe, and sends the SYN.  The provider owns
// the returned socket.
TCPSocket &
TCPProvider::connect(NetOrder<std::uint32_t> ip, std::uint16_t port)
{
    std::unique_ptr<TCPSocket> owned(new TCPSocket(*this));
    TCPSocket &conn = *owned;

    conn.remotePort = toNetOrder(port);
    conn.remoteIP = ip;
    conn.localPort = toNetOrder(freePort++);
    conn.localIP = backend.getIPAddress();

    sockets[numSockets] = std::move(owned);
    ++numSockets;

    conn.state = TCPSocketState::SynSent;
    conn.seqNum = 0xbeefcafe;
    send(conn, {}, TCPFlags::SYN);

    return conn;
}

// Starts an active close of `socket`.  It enters FinWait1 and sends FIN|ACK,
// then accounts for the one sequence number the FIN consumes.
void
TCPProvider::disconnect(TCPSocket &socket)
{
    TCPSocket &closing = socket;
    closing.state = TCPSocketState::FinWait1;
    send(closing, {}, TCPFlags::FIN | TCPFlags::ACK);
    closing.seqNum += 1;
}

// Creates a passive socket in the Listen state on this host's IP address
// and `port` (host byte order), and adds it to the socket table.  The
// provider owns the returned socket.
TCPSocket &
TCPProvider::listen(std::uint16_t port)
{
    std::unique_ptr<TCPSocket> owned(new TCPSocket(*this));
    TCPSocket &listener = *owned;

    listener.state = TCPSocketState::Listen;
    listener.localIP = backend.getIPAddress();
    listener.localPort = toNetOrder(port);

    sockets[numSockets] = std::move(owned);
    ++numSockets;
    return listener;
}

// Attaches `handler` to `socket`.  Incoming payload on that socket is then
// passed to the handler.  A nullptr detaches any handler; the socket then
// reports payload as unhandled.
void
TCPProvider::bind(TCPSocket &socket, TCPHandler *handler)
{
    TCPSocket &target = socket;
    target.handler = handler;
}
Hints

Before your first commit, do not forget to set up your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://code.reversed.top/user/xaizek/tos

Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@code.reversed.top/user/xaizek/tos

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a pull request:
... clone the repository ...
... make some changes and some commits ...
git push origin master