#ifndef TOS__NET__TCP_HPP__ #define TOS__NET__TCP_HPP__ #include <cstdint> #include <memory> #include "net/ipv4.hpp" #include "utils/span.hpp" #include "utils/struct.hpp" namespace net { enum class TCPSocketState { Closed, Listen, SynSent, SynReceived, Established, FinWait1, FinWait2, Closing, TimeWait, CloseWait, // LastAck, }; enum class TCPFlags : std::uint8_t { None = 0, FIN = 1, SYN = 2, RST = 4, PSH = 8, ACK = 16, URG = 32, ECE = 64, CWR = 128, // NS = 256 }; struct TCPFlagsSubset { constexpr TCPFlagsSubset(TCPFlags flags) : val(static_cast<std::uint8_t>(flags)) { } constexpr TCPFlagsSubset(std::uint8_t val) : val(val) { } constexpr TCPFlagsSubset(int val) : val(val) { } constexpr operator std::uint8_t() const { return val; } std::uint8_t val; }; constexpr bool operator==(TCPFlagsSubset lhs, TCPFlagsSubset rhs) { return lhs.val == rhs.val; } constexpr bool operator!=(TCPFlagsSubset lhs, TCPFlagsSubset rhs) { return !(lhs == rhs); } constexpr TCPFlagsSubset operator+(TCPFlags flags) { return TCPFlagsSubset{ static_cast<std::uint8_t>(flags) }; } constexpr TCPFlagsSubset operator&(TCPFlags lhs, TCPFlagsSubset rhs) { return TCPFlagsSubset { static_cast<std::uint8_t>(lhs) & rhs.val }; } constexpr TCPFlagsSubset operator|(TCPFlags lhs, TCPFlagsSubset rhs) { return TCPFlagsSubset { static_cast<std::uint8_t>(lhs) | rhs.val }; } constexpr TCPFlagsSubset operator&(TCPFlagsSubset lhs, TCPFlags rhs) { return TCPFlagsSubset { lhs.val & static_cast<std::uint8_t>(rhs) }; } constexpr TCPFlagsSubset operator|(TCPFlagsSubset lhs, TCPFlags rhs) { return TCPFlagsSubset { lhs.val | static_cast<std::uint8_t>(rhs) }; } constexpr TCPFlagsSubset operator&(TCPFlags lhs, TCPFlags rhs) { return static_cast<TCPFlagsSubset>( static_cast<std::uint8_t>(lhs) & static_cast<std::uint8_t>(rhs) ); } constexpr TCPFlagsSubset operator|(TCPFlags lhs, TCPFlags rhs) { return static_cast<TCPFlagsSubset>( static_cast<std::uint8_t>(lhs) | static_cast<std::uint8_t>(rhs) ); } struct TCPHeader { NetField<+0, 2, std::uint16_t> srcPort; NetField<+2, 2, std::uint16_t> dstPort; NetField<+4, 4, std::uint32_t> seqNum; NetField<+8, 4, std::uint32_t> ackNum; NetField<+12, 1, std::uint8_t> offsetAndReserved; NetField<+13, 1, std::uint8_t> flags; NetField<+14, 2, std::uint16_t> windowSize; NetField<+16, 2, std::uint16_t> checksum; NetField<+18, 2, std::uint16_t> urgentPtr; NetField<+20, 4, std::uint32_t> options; explicit TCPHeader(std::uint8_t data[]) : srcPort(data), dstPort(data), seqNum(data), ackNum(data), offsetAndReserved(data), flags(data), windowSize(data), checksum(data), urgentPtr(data), options(data) { } static constexpr std::size_t size() { using lastField = decltype(options); return lastField::offset + lastField::size; } }; struct TCPPseudoHeader { NetField<+0, 4, std::uint32_t> srcIP; NetField<+4, 4, std::uint32_t> dstIP; NetField<+8, 2, std::uint16_t> protocol; NetField<+10, 2, std::uint16_t> totalLength; explicit TCPPseudoHeader(std::uint8_t data[]) : srcIP(data), dstIP(data), protocol(data), totalLength(data) { } static constexpr std::size_t size() { using lastField = decltype(totalLength); return lastField::offset + lastField::size; } }; class TCPSocket; class TCPProvider; class TCPHandler { public: virtual ~TCPHandler() = default; public: // Returns `true` to signal that connection is still alive and shouldn't be // terminated. virtual bool handleTCPMessage(TCPSocket &socket, span<std::uint8_t> msg) = 0; }; class TCPSocket { friend class TCPProvider; public: TCPSocket(TCPProvider &backend); virtual ~TCPSocket() = default; public: virtual bool handleTCPMessage(span<std::uint8_t> msg); virtual void send(span<const std::uint8_t> buf); virtual void disconnect(); protected: NetOrder<std::uint16_t> remotePort; NetOrder<std::uint32_t> remoteIP; NetOrder<std::uint16_t> localPort; NetOrder<std::uint32_t> localIP; std::uint32_t seqNum; std::uint32_t ackNum; TCPProvider &backend; TCPHandler *handler; TCPSocketState state; }; class TCPProvider : IPHandler { public: TCPProvider(IPProvider &backend); public: virtual bool onIPReceived(NetOrder<std::uint32_t> srcIP, NetOrder<std::uint32_t> dstIP, span<std::uint8_t> msg) override; virtual TCPSocket & connect(NetOrder<std::uint32_t> ip, std::uint16_t port); virtual void disconnect(TCPSocket &socket); virtual void send(TCPSocket &socket, span<const std::uint8_t> buf, TCPFlagsSubset flags = TCPFlags::None); virtual TCPSocket & listen(std::uint16_t port); virtual void bind(TCPSocket &socket, TCPHandler *handler); private: std::unique_ptr<TCPSocket> sockets[65536]; std::uint16_t numSockets; std::uint16_t freePort; }; } #endif // TOS__NET__TCP_HPP__