#include "net/tcp.hpp"
#include <cstddef>
#include <memory>
#include <utility>
#include "utils/conv.hpp"
using namespace net;
TCPSocket::TCPSocket(TCPProvider &backend)
: backend(backend), handler(nullptr), state(TCPSocketState::Closed)
{
}
bool
TCPSocket::handleTCPMessage(span<std::uint8_t> msg)
{
    // Forwards a received payload to the bound handler.  With no handler
    // bound the payload is not consumed and false is reported.
    return handler != nullptr && handler->handleTCPMessage(*this, msg);
}
// Sends buf to the peer as a single segment with the ACK flag set, after
// busy-waiting until the connection reaches Established.  If the socket is
// observed Closed while waiting, nothing is sent and the call just returns.
void
TCPSocket::send(span<const std::uint8_t> buf)
{
    // The state is changed outside this loop (presumably by
    // TCPProvider::onIPReceived on the receive path -- confirm), so it is read
    // through a volatile pointer to force a fresh load on every iteration.
    // NOTE(review): volatile is not a synchronization primitive; this relies
    // on the platform's execution model for the update to become visible.
    auto state = const_cast<volatile TCPSocketState *>(&this->state);
    // Wait. (XXX: Potentially forever...)
    while (*state != TCPSocketState::Established) {
        // XXX: TCPProvider::onIPReceived destroys sockets that reach Closed,
        // so by the time Closed could be seen here this object may already be
        // freed and this read is a use-after-free.
        if (*state == TCPSocketState::Closed) {
            return;
        }
    }
    backend.send(*this, buf, TCPFlags::ACK);
}
void
TCPSocket::disconnect()
{
backend.disconnect(*this);
}
TCPProvider::TCPProvider(IPProvider &backend)
: IPHandler(backend, 0x06), numSockets(0), freePort(1024)
{
}
// Handles one TCP segment delivered by the IP layer: finds the socket it
// belongs to (or answers with a reset when there is none), advances that
// socket's state machine, hands payload to the socket's handler and destroys
// the socket once it reaches Closed.  Always returns false.
bool
TCPProvider::onIPReceived(NetOrder<std::uint32_t> srcIP,
                          NetOrder<std::uint32_t> dstIP,
                          span<std::uint8_t> msg)
{
    // A TCP header without options is 20 bytes long.
    if (msg.size() < 20) {
        return false;
    }

    TCPHeader tcpHdr(msg.data());

    // The upper nibble of the data-offset byte is the header length in 32-bit
    // words.  It comes straight from the wire, so reject values below the
    // minimum or past the end of the segment; otherwise msg.after(headerSize)
    // below would read out of bounds.
    const std::size_t headerSize = (tcpHdr.offsetAndReserved.BE >> 4)*4;
    if (headerSize < 20 || headerSize > msg.size()) {
        return false;
    }
    const std::size_t payloadSize = msg.size() - headerSize;

    auto flags = static_cast<TCPFlags>(tcpHdr.flags.BE);

    // A listening socket only accepts a SYN without ACK; every other segment
    // has to match the full local/remote endpoint pair.
    TCPSocket *socket = nullptr;
    for (std::uint16_t i = 0; i < numSockets && socket == nullptr; ++i) {
        if (sockets[i]->localPort == tcpHdr.dstPort &&
            sockets[i]->localIP == dstIP &&
            sockets[i]->state == TCPSocketState::Listen &&
            (flags & (TCPFlags::SYN | TCPFlags::ACK)) == TCPFlags::SYN) {
            socket = sockets[i].get();
        } else if (sockets[i]->localPort == tcpHdr.dstPort &&
                   sockets[i]->localIP == dstIP &&
                   sockets[i]->remotePort == tcpHdr.srcPort &&
                   sockets[i]->remoteIP == srcIP) {
            socket = sockets[i].get();
        }
    }

    if (socket == nullptr) {
        // RFC 793: an incoming RST for an unknown connection is discarded;
        // answering it with another RST could bounce resets back and forth.
        if ((flags & TCPFlags::RST) != TCPFlags::None) {
            return false;
        }

        // RFC 793 reset generation for a closed port: if the segment carried
        // an ACK, the RST takes its sequence number from that ACK; otherwise
        // it uses sequence 0 and acknowledges the whole segment (SYN and FIN
        // count as one octet each) so the peer accepts it.
        std::uint32_t segLen = static_cast<std::uint32_t>(payloadSize);
        if ((flags & TCPFlags::SYN) != TCPFlags::None) {
            ++segLen;
        }
        if ((flags & TCPFlags::FIN) != TCPFlags::None) {
            ++segLen;
        }

        TCPSocket rst(*this);
        rst.remotePort = tcpHdr.srcPort;
        rst.remoteIP = srcIP;
        rst.localPort = tcpHdr.dstPort;
        rst.localIP = dstIP;
        rst.ackNum = ntoh(tcpHdr.seqNum.BE) + segLen;
        if ((flags & TCPFlags::ACK) != TCPFlags::None) {
            rst.seqNum = ntoh(tcpHdr.ackNum.BE);
            send(rst, {}, TCPFlags::RST);
        } else {
            rst.seqNum = 0;
            send(rst, {}, TCPFlags::RST | TCPFlags::ACK);
        }
        return false;
    }

    bool reset = false;
    if ((flags & TCPFlags::RST) != TCPFlags::None) {
        socket->state = TCPSocketState::Closed;
    }
    if (socket->state != TCPSocketState::Closed) {
        switch (flags & (TCPFlags::SYN | TCPFlags::ACK | TCPFlags::FIN)) {
            case +TCPFlags::SYN:
                if (socket->state != TCPSocketState::Listen) {
                    reset = true;
                    break;
                }
                // Passive open: the listening socket itself takes over the
                // connection and answers with SYN+ACK.
                socket->state = TCPSocketState::SynReceived;
                socket->remotePort = tcpHdr.srcPort;
                socket->remoteIP = srcIP;
                socket->ackNum = ntoh(tcpHdr.seqNum.BE) + 1;
                socket->seqNum = 0xbeefcafe;
                send(*socket, {}, TCPFlags::SYN | TCPFlags::ACK);
                // Our SYN consumes one sequence number.
                ++socket->seqNum;
                break;
            case TCPFlags::SYN | TCPFlags::ACK:
                if (socket->state == TCPSocketState::SynSent) {
                    socket->state = TCPSocketState::Established;
                    socket->ackNum = ntoh(tcpHdr.seqNum.BE) + 1;
                    ++socket->seqNum;
                    send(*socket, {}, TCPFlags::ACK);
                } else {
                    reset = true;
                }
                break;
            case TCPFlags::SYN | TCPFlags::FIN:
            case TCPFlags::SYN | TCPFlags::FIN | TCPFlags::ACK:
                reset = true;
                break;
            case +TCPFlags::FIN:
            case TCPFlags::FIN | TCPFlags::ACK:
                if (socket->state == TCPSocketState::Established) {
                    // Peer closed: acknowledge its FIN and close our side
                    // right away.  Our FIN consumes one sequence number, as
                    // in disconnect().
                    socket->state = TCPSocketState::CloseWait;
                    ++socket->ackNum;
                    send(*socket, {}, TCPFlags::ACK);
                    send(*socket, {}, TCPFlags::FIN | TCPFlags::ACK);
                    ++socket->seqNum;
                } else if (socket->state == TCPSocketState::CloseWait) {
                    socket->state = TCPSocketState::Closed;
                } else if (socket->state == TCPSocketState::FinWait1 ||
                           socket->state == TCPSocketState::FinWait2) {
                    socket->state = TCPSocketState::Closed;
                    ++socket->ackNum;
                    send(*socket, {}, TCPFlags::ACK);
                } else {
                    reset = true;
                }
                break;
            case +TCPFlags::ACK:
                if (socket->state == TCPSocketState::SynReceived) {
                    socket->state = TCPSocketState::Established;
                } else if (socket->state == TCPSocketState::FinWait1) {
                    socket->state = TCPSocketState::FinWait2;
                } else if (socket->state == TCPSocketState::CloseWait) {
                    socket->state = TCPSocketState::Closed;
                    break;
                }
                // Data may arrive with or without PSH (and even on the ACK
                // completing the handshake); only a bare acknowledgement has
                // nothing further to process.
                if (payloadSize == 0) {
                    break;
                }
                // fall through: deliver the piggybacked data
            default:
                if (ntoh(tcpHdr.seqNum.BE) == socket->ackNum) {
                    reset = !socket->handleTCPMessage(msg.after(headerSize));
                    if (!reset) {
                        socket->ackNum += payloadSize;
                        send(*socket, {}, TCPFlags::ACK);
                    }
                } else {
                    // data in wrong order
                    reset = true;
                }
                break;
        }
    }

    if (reset) {
        send(*socket, {}, TCPFlags::RST);
    }

    if (socket->state == TCPSocketState::Closed) {
        for (std::uint16_t i = 0U; i < numSockets; ++i) {
            if (sockets[i].get() == socket) {
                // Swap the closed socket into the last used slot and release
                // it there.  The unique_ptr is the sole owner, so the object
                // must not also be deleted by hand (that was a double free).
                --numSockets;
                std::swap(sockets[i], sockets[numSockets]);
                sockets[numSockets].reset();
                break;
            }
        }
    }
    return false;
}
// Builds one TCP segment (header followed by buf) for socket, computes its
// checksum over the pseudo-header and hands it to the IP layer.  The socket's
// sequence number is advanced by the payload length; SYN/FIN sequence
// accounting is left to the callers.  Payloads too large for the 16-bit
// length fields are not sent at all (see below).
void
TCPProvider::send(TCPSocket &socket, span<const std::uint8_t> buf,
                  TCPFlagsSubset flags)
{
    // The TCP length and the buffer length (which also covers the
    // pseudo-header) are 16-bit values.  A larger payload would wrap them,
    // under-allocate the buffer and then overflow it while copying the
    // payload.  Such a segment cannot be described, so refuse it.
    const std::size_t maxPayload =
        0xffffU - TCPHeader::size() - TCPPseudoHeader::size();
    if (buf.size() > maxPayload) {
        return;
    }

    std::uint16_t totalLength =
        static_cast<std::uint16_t>(buf.size() + TCPHeader::size());
    std::uint16_t lengthInclPHdr =
        static_cast<std::uint16_t>(totalLength + TCPPseudoHeader::size());

    // Layout: [pseudo-header][TCP header][payload].  Only the last two parts
    // are transmitted; the pseudo-header exists for the checksum.
    std::unique_ptr<std::uint8_t[]> buffer(new std::uint8_t[lengthInclPHdr]);
    TCPPseudoHeader phdr(buffer.get());
    TCPHeader tcpHdr(buffer.get() + TCPPseudoHeader::size());
    std::uint8_t *payload = buffer.get()
                          + TCPHeader::size()
                          + TCPPseudoHeader::size();

    // Data offset in 32-bit words lives in the upper nibble.
    tcpHdr.offsetAndReserved =
        netOrder(static_cast<std::uint8_t>(TCPHeader::size()/4 << 4));
    tcpHdr.srcPort = socket.localPort;
    tcpHdr.dstPort = socket.remotePort;
    tcpHdr.ackNum = toNetOrder(socket.ackNum);
    tcpHdr.seqNum = toNetOrder(socket.seqNum);
    tcpHdr.flags = netOrder(flags.val);
    tcpHdr.windowSize = toNetOrder<std::uint16_t>(0xffff);
    tcpHdr.urgentPtr = toNetOrder<std::uint16_t>(0);
    // SYN segments carry the MSS option (kind 2, length 4, value 1460).
    tcpHdr.options = toNetOrder<std::uint32_t>(
        ((flags & TCPFlags::SYN) != TCPFlags::None) ? 0x020405b4 : 0
    );

    socket.seqNum += buf.size();
    for (std::size_t i = 0; i < buf.size(); ++i) {
        payload[i] = buf[i];
    }

    phdr.srcIP = socket.localIP;
    phdr.dstIP = socket.remoteIP;
    phdr.protocol = toNetOrder<std::uint16_t>(0x0006);
    phdr.totalLength = toNetOrder(totalLength);

    // The checksum field must be zero while the checksum is computed.
    tcpHdr.checksum = toNetOrder<std::uint16_t>(0);
    tcpHdr.checksum = IPProvider::checksum({ buffer.get(), lengthInclPHdr });

    IPHandler::send(socket.remoteIP,
                    { buffer.get() + TCPPseudoHeader::size(), totalLength });
}
// Actively opens a connection to ip:port (port in host order) from the next
// free local port, sends the initial SYN and returns the new socket, which
// stays owned by the provider.
// NOTE(review): there is no check that the sockets table has room left.
TCPSocket &
TCPProvider::connect(NetOrder<std::uint32_t> ip, std::uint16_t port)
{
    std::unique_ptr<TCPSocket> owned(new TCPSocket(*this));
    TCPSocket &conn = *owned;

    // Endpoint addressing; ports are stored in network byte order.
    conn.localIP = backend.getIPAddress();
    conn.localPort = toNetOrder(freePort++);
    conn.remoteIP = ip;
    conn.remotePort = toNetOrder(port);

    sockets[numSockets++] = std::move(owned);

    // Active open: announce our initial sequence number with a SYN.
    conn.state = TCPSocketState::SynSent;
    conn.seqNum = 0xbeefcafe;
    send(conn, {}, TCPFlags::SYN);
    return conn;
}
// Starts an active close of socket: enters FinWait1 and sends FIN+ACK.
void
TCPProvider::disconnect(TCPSocket &socket)
{
    TCPSocket &closing = socket;
    closing.state = TCPSocketState::FinWait1;
    send(closing, {}, TCPFlags::FIN | TCPFlags::ACK);
    // The FIN occupies one sequence number; send() only advances seqNum by
    // the payload length, which is zero here.
    closing.seqNum += 1;
}
// Creates a socket listening on the given local port (host order) of this
// host's IP address and returns it; the provider keeps ownership.
// NOTE(review): there is no check that the sockets table has room left.
TCPSocket &
TCPProvider::listen(std::uint16_t port)
{
    std::unique_ptr<TCPSocket> owned(new TCPSocket(*this));
    TCPSocket &listener = *owned;
    listener.localPort = toNetOrder(port);
    listener.localIP = backend.getIPAddress();
    listener.state = TCPSocketState::Listen;
    sockets[numSockets++] = std::move(owned);
    return listener;
}
// Attaches handler to socket so received payloads are passed to it; passing
// nullptr detaches any previously bound handler.
void
TCPProvider::bind(TCPSocket &socket, TCPHandler *handler)
{
    TCPSocket &target = socket;
    target.handler = handler;
}
Before your first commit, remember to set up your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"
Clone this repository using HTTP(S):
git clone https://code.reversed.top/user/xaizek/tos
Clone this repository using SSH (remember to upload your key first):
git clone ssh://rocketgit@code.reversed.top/user/xaizek/tos
You may push to this repository anonymously.
Your pushed commits will automatically be turned into a
pull request:
... clone the repository ...
... make some changes and some commits ...
git push origin master