From 7d1343a0d7e005e945a60219164c4b97a5259b5e Mon Sep 17 00:00:00 2001 From: hkc Date: Mon, 8 Jan 2024 16:11:18 +0300 Subject: [PATCH] Final? --- wsvpn.c | 90 +++++++++++++++++++++++++++++++++++++++++++++++-------- wsvpn.lua | 15 ++++++++-- 2 files changed, 90 insertions(+), 15 deletions(-) diff --git a/wsvpn.c b/wsvpn.c index 2f11ecb..5192596 100644 --- a/wsvpn.c +++ b/wsvpn.c @@ -10,6 +10,52 @@ #define MAX_OPEN_CHANNELS 128 #define MAX_PACKET_SIZE 32767 +/* +```c +// ALL integers are network-endian + +struct msg_info { // side: both. safe to ignore + uint8_t code; // 0x49, 'I' + uint16_t len; + char msg[1024]; // $.len bytes sent +}; + +struct msg_res_error { + uint8_t code; // 0x45, 'E' + uint16_t req_id; // request ID that caused that error + uint16_t len; + char msg[1024]; // $.len bytes sent +}; + +struct msg_address { // side: server + uint8_t code; // 0x41, 'A' + uint8_t size; + char name[256]; // $.size long +}; + +struct msg_res_success { // side: server + uint8_t code; // 0x52, 'R' + uint16_t req_id; // request ID we're replying to + void *data; // packet-specific +}; + +struct msg_transmission { // side: server + uint8_t code; // 0x54, 'T' + uint16_t channel; + uint16_t replyChannel; + uint16_t size; + void *data; +}; + +struct msg_req_open { // side: client + uint8_t code; // 0x4f, 'O' + uint16_t req_id; // incremental request ID + uint16_t channel; // channel to be open +}; + +``` +*/ + struct client { struct mg_connection *connection; uint16_t open_channels[MAX_OPEN_CHANNELS]; @@ -24,6 +70,7 @@ static void on_ws_connect(struct mg_connection *connection, struct mg_http_messa static void on_ws_message(struct mg_connection *connection, struct mg_ws_message *message, void *data); static void on_ws_disconnect(struct mg_connection *connection, void *data); +bool client_is_open(struct client *client, uint16_t channel); static void modem_open(struct client *client, uint16_t request_id, uint16_t channel); static void modem_isOpen(struct client *client, uint16_t request_id, uint16_t channel); static void modem_close(struct client *client, uint16_t request_id, uint16_t channel); @@ -146,12 +193,14 @@ void ws_respond(struct mg_connection *connection, uint16_t request_id, void *dat } static void on_ws_connect(struct mg_connection *connection, struct mg_http_message *message, void *data) { - static char addr_str[256]; - mg_snprintf(addr_str, 256, "%M", mg_print_ip_port, &connection->rem); - ws_send_info(connection, "Hello, %s", addr_str); struct client *client = malloc(sizeof(struct client)); memcpy(&connection->data[0], &client, sizeof(struct client *)); client->connection = connection; + + static char buffer[256]; + buffer[0] = 'A'; + buffer[1] = snprintf(&buffer[2], 250, "wsvpn_%ld", connection->id); + mg_ws_send(connection, buffer, 2 + buffer[1], WEBSOCKET_OP_BINARY); } static void on_ws_message(struct mg_connection *connection, struct mg_ws_message *message, void *data) { @@ -171,7 +220,7 @@ static void on_ws_message(struct mg_connection *connection, struct mg_ws_message uint16_t request_id = ntohs(*(uint16_t*)&message->data.ptr[1]); switch (message->data.ptr[0]) { - case 'I': // info. We can safely ignore that channel + case 'I': // info. We can safely ignore that message break; case 'O': // open { @@ -205,7 +254,7 @@ static void on_ws_message(struct mg_connection *connection, struct mg_ws_message uint16_t channel = ntohs(*(uint16_t*)&message->data.ptr[3]); uint16_t reply_channel = ntohs(*(uint16_t*)&message->data.ptr[5]); uint16_t data_length = ntohs(*(uint16_t*)&message->data.ptr[7]); - modem_transmit(client, request_id, channel, reply_channel, data, data_length); + modem_transmit(client, request_id, channel, reply_channel, (void*)&message->data.ptr[9], data_length); } return; default: @@ -240,13 +289,7 @@ static void modem_open(struct client *client, uint16_t request_id, uint16_t chan } static void modem_isOpen(struct client *client, uint16_t request_id, uint16_t channel) { - unsigned char is_open = 0; - for (int i = 0; i < client->next_open_channel_index; i++) { - if (client->open_channels[i] == channel) { - is_open = 42; - break; - } - } + unsigned char is_open = client_is_open(client, channel) ? 42 : 0; ws_respond(client->connection, request_id, &is_open, 1); } @@ -275,7 +318,7 @@ static void modem_transmit(struct client *client, uint16_t request_id, uint16_t return; } - buffer[0] = 'M'; + buffer[0] = 'T'; buffer[1] = (channel >> 8) & 0xFF; buffer[2] = channel & 0xFF; buffer[3] = (reply_channel >> 8) & 0xFF; @@ -283,4 +326,25 @@ static void modem_transmit(struct client *client, uint16_t request_id, uint16_t buffer[5] = (size >> 8) & 0xFF; buffer[6] = size & 0xFF; memcpy(&buffer[7], data, size); + + for (struct mg_connection *conn = client->connection->mgr->conns; conn != NULL; conn = conn->next) { + if (conn->is_websocket) { + struct client *other_client = *(struct client **)&conn->data[0]; + if (other_client->connection == conn && other_client->connection != client->connection) { + if (client_is_open(other_client, channel)) { + mg_ws_send(other_client->connection, buffer, size + 7, WEBSOCKET_OP_BINARY); + } + } + } + } + ws_respond(client->connection, request_id, NULL, 0); +} + +bool client_is_open(struct client *client, uint16_t channel) { + for (int i = 0; i < client->next_open_channel_index; i++) { + if (client->open_channels[i] == channel) { + return true; + } + } + return false; } diff --git a/wsvpn.lua b/wsvpn.lua index c79bb38..edab371 100644 --- a/wsvpn.lua +++ b/wsvpn.lua @@ -38,11 +38,13 @@ local WSModem = { local serialized = textutils.serializeJSON(data) expect.range(#serialized, 0, 65535) serialized = { serialized:byte(1, 65536) } - return self._request(0x54, { + self._request(0x54, { bit.band(0xFF, bit.brshift(channel, 8)), bit.band(0xFF, channel), bit.band(0xFF, bit.brshift(replyChannel, 8)), bit.band(0xFF, replyChannel), + bit.band(0xFF, bit.brshift(#serialized, 8)), + bit.band(0xFF, #serialized), table.unpack(serialized, 1, #serialized) }) end, @@ -58,6 +60,9 @@ local WSModem = { local len, msg = self._read_u16ne(data) msg = string.char(table.unpack(msg)) os.queueEvent("wsvpn:info", msg) + elseif opcode == 0x41 then -- Set address/side + local len = table.remove(data, 1) + self.side = string.char(table.unpack(data, 1, len)) elseif opcode == 0x45 then -- Error local request_id, error_length request_id, data = self._read_u16ne(data) @@ -67,6 +72,12 @@ local WSModem = { elseif opcode == 0x52 then -- Response local request_id, response = self._read_u16ne(data) os.queueEvent("wsvpn:response", true, request_id, response) + elseif opcode == 0x54 then -- Transmission + local channel, replyChannel, dataSize, packet + channel, data = self._read_u16ne(data) + replyChannel, data = self._read_u16ne(data) + dataSize, packet = self._read_u16ne(data) + os.queueEvent("modem_message", self.side or "wsmodem_0", channel, replyChannel, textutils.unserializeJSON(string.char(table.unpack(data, 1, dataSize))), nil) else return false, string.format("Invalid opcode 0x%02x", opcode) end @@ -134,7 +145,7 @@ local WSModem = { return function(addr) local ws = assert(http.websocket(addr)) - local sock = setmetatable({ _socket = ws, _req_id = 0 }, { __index = WSModem }) + local sock = setmetatable({ _socket = ws, _req_id = 0, side = "wsmodem_unknown" }, { __index = WSModem }) for name, method in pairs(WSModem) do sock[name] = function(...) return method(sock, ...) end end