// x-run: ~/scripts/runc.sh % -lmongoose -Wall -Wextra #include #include #include #include #include #include #define MAX_CLIENTS 256 #define MAX_OPEN_CHANNELS 128 #define MAX_PACKET_SIZE 32767 /* ```c // ALL integers are network-endian struct msg_info { // side: both. safe to ignore uint8_t code; // 0x49, 'I' uint16_t len; char msg[1024]; // $.len bytes sent }; struct msg_res_error { uint8_t code; // 0x45, 'E' uint16_t req_id; // request ID that caused that error uint16_t len; char msg[1024]; // $.len bytes sent }; struct msg_address { // side: server uint8_t code; // 0x41, 'A' uint8_t size; char name[256]; // $.size long }; struct msg_res_success { // side: server uint8_t code; // 0x52, 'R' uint16_t req_id; // request ID we're replying to void *data; // packet-specific }; struct msg_transmission { // side: server uint8_t code; // 0x54, 'T' uint16_t channel; uint16_t replyChannel; uint16_t size; void *data; }; struct msg_req_open { // side: client uint8_t code; // 0x4f, 'O' uint16_t req_id; // incremental request ID uint16_t channel; // channel to be open }; ``` */ struct client { struct mg_connection *connection; uint16_t open_channels[MAX_OPEN_CHANNELS]; int next_open_channel_index; bool receive_all; }; struct client clients[MAX_CLIENTS] = { 0 }; static void handle_client(struct mg_connection *connection, int event_type, void *ev_data, void *fn_data); static void on_ws_connect(struct mg_connection *connection, struct mg_http_message *message, void *data); static void on_ws_message(struct mg_connection *connection, struct mg_ws_message *message, void *data); static void on_ws_disconnect(struct mg_connection *connection, void *data); bool client_is_open(struct client *client, uint16_t channel); static void modem_open(struct client *client, uint16_t request_id, uint16_t channel); static void modem_isOpen(struct client *client, uint16_t request_id, uint16_t channel); static void modem_close(struct client *client, uint16_t request_id, uint16_t channel); static void modem_closeAll(struct client *client, uint16_t request_id); static void modem_transmit(struct client *client, uint16_t request_id, uint16_t channel, uint16_t reply_channel, void *data, uint16_t size); struct metrics { uint64_t sent_bytes; uint64_t sent_messages; uint64_t received_bytes; uint64_t received_messages; uint64_t errors; uint64_t method_calls[5]; } metrics = { 0 }; const char method_names[5][8] = { "open", "isOpen", "close", "closeAll", "transmit" }; int main(void) { const char *address = "ws://0.0.0.0:8667"; struct mg_mgr manager; mg_mgr_init(&manager); mg_http_listen(&manager, address, handle_client, NULL); printf("Listening on %s\n", address); while (1) mg_mgr_poll(&manager, 1000); mg_mgr_free(&manager); } static void handle_client(struct mg_connection *connection, int event_type, void *event_data, void *fn_data) { if (event_type == MG_EV_OPEN) { if (connection->rem.port == 0) return; memset(connection->data, 0, 32); } else if (event_type == MG_EV_HTTP_MSG) { struct mg_http_message *http_message = (struct mg_http_message *) event_data; if (mg_http_match_uri(http_message, "/open")) { mg_ws_upgrade(connection, http_message, NULL); } else if (mg_http_match_uri(http_message, "/metrics")) { mg_printf(connection, "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"); mg_http_printf_chunk(connection, "# HELP ws_bytes_sent_total Number of bytes sent to clients\n"); mg_http_printf_chunk(connection, "# TYPE ws_bytes_sent_total counter\n"); mg_http_printf_chunk(connection, "ws_bytes_sent_total %ld\n", metrics.sent_bytes); mg_http_printf_chunk(connection, "# HELP ws_bytes_received_total Number of bytes received to clients\n"); mg_http_printf_chunk(connection, "# TYPE ws_bytes_received_total counter\n"); mg_http_printf_chunk(connection, "ws_bytes_received_total %ld\n", metrics.received_bytes); mg_http_printf_chunk(connection, "# HELP ws_messages_sent_total Number of messages sent to clients\n"); mg_http_printf_chunk(connection, "# TYPE ws_messages_sent_total counter\n"); mg_http_printf_chunk(connection, "ws_messages_sent_total %ld\n", metrics.sent_messages); mg_http_printf_chunk(connection, "# HELP ws_messages_received_total Number of messages received to clients\n"); mg_http_printf_chunk(connection, "# TYPE ws_messages_received_total counter\n"); mg_http_printf_chunk(connection, "ws_messages_received_total %ld\n", metrics.received_messages); mg_http_printf_chunk(connection, "# HELP ws_clients Number of active websocket clients\n"); mg_http_printf_chunk(connection, "# TYPE ws_clients gauge\n"); { int n = 0; for (struct mg_connection *conn = connection->mgr->conns; conn != NULL; conn = conn->next) { if (conn->is_websocket) { n++; } } mg_http_printf_chunk(connection, "ws_clients %d\n", n); } mg_http_printf_chunk(connection, "# HELP method_calls Times each method was called\n"); mg_http_printf_chunk(connection, "# TYPE method_calls counter\n"); for (int i = 0; i < 5; i++) { mg_http_printf_chunk(connection, "method_calls{method=\"%s\"} %ld\n", method_names[i], metrics.method_calls[i]); } mg_http_printf_chunk(connection, ""); } else { mg_http_reply(connection, 404, "", "uwu"); } } else if (event_type == MG_EV_WS_OPEN) { struct mg_http_message *http_message = (struct mg_http_message *) event_data; on_ws_connect(connection, http_message, fn_data); } else if (event_type == MG_EV_WS_MSG) { struct mg_ws_message *ws_message = (struct mg_ws_message *)event_data; on_ws_message(connection, ws_message, fn_data); } else if (event_type == MG_EV_CLOSE) { if (connection->is_websocket) { on_ws_disconnect(connection, fn_data); } } } void ws_send_error(struct client *client, uint16_t request_id, const char *fmt, ...) { static char buffer[1024]; memset(buffer, 0, 1024); va_list args; va_start(args, fmt); int text_size = vsnprintf(&buffer[5], 1019, fmt, args); va_end(args); if (text_size < 0) return; buffer[0] = 'E'; buffer[1] = (request_id >> 8) & 0xFF; buffer[2] = request_id & 0xFF; buffer[3] = (text_size >> 8) & 0xFF; buffer[4] = text_size & 0xFF; metrics.sent_bytes += 5 + text_size; metrics.sent_messages++; metrics.errors++; mg_ws_send(client->connection, buffer, 5 + text_size, WEBSOCKET_OP_BINARY); } void ws_send_info(struct client *client, const char *fmt, ...) { static char buffer[1024]; memset(buffer, 0, 1024); va_list args; va_start(args, fmt); int text_size = vsnprintf(&buffer[3], 1021, fmt, args); va_end(args); if (text_size < 0) return; buffer[0] = 'I'; buffer[1] = (text_size >> 8) & 0xFF; buffer[2] = text_size & 0xFF; metrics.sent_bytes += 3 + text_size; metrics.sent_messages++; mg_ws_send(client->connection, buffer, 3 + text_size, WEBSOCKET_OP_BINARY); } void ws_respond(struct client *client, uint16_t request_id, void *data, uint32_t size) { static char buffer[MAX_PACKET_SIZE]; assert(size < MAX_PACKET_SIZE); buffer[0] = 'R'; buffer[1] = (request_id >> 8) & 0xFF; buffer[2] = request_id & 0xFF; if (size != 0) memcpy(&buffer[3], data, size); metrics.sent_bytes += 3 + size; metrics.sent_messages++; mg_ws_send(client->connection, buffer, size + 3, WEBSOCKET_OP_BINARY); } static void on_ws_connect(struct mg_connection *connection, struct mg_http_message *message, void *data) { (void)message; (void)data; struct client *client = malloc(sizeof(struct client)); memcpy(&connection->data[0], &client, sizeof(struct client *)); client->connection = connection; static char buffer[256]; buffer[0] = 'A'; buffer[1] = snprintf(&buffer[2], 250, "wsvpn_%ld", connection->id); metrics.sent_bytes += 2 + buffer[1]; metrics.sent_messages++; mg_ws_send(connection, buffer, 2 + buffer[1], WEBSOCKET_OP_BINARY); } static void on_ws_message(struct mg_connection *connection, struct mg_ws_message *message, void *data) { (void)data; if ((message->flags & 15) != WEBSOCKET_OP_BINARY) { const char *err_str = "This server only works in binary mode. Sorry!"; mg_ws_send(connection, err_str, strlen(err_str), WEBSOCKET_OP_TEXT); connection->is_draining = 1; return; } struct client *client = *(struct client **)&connection->data[0]; assert(client->connection == connection); metrics.received_bytes += message->data.len; metrics.received_messages++; if (message->data.len == 0) return; uint16_t request_id = ntohs(*(uint16_t*)&message->data.ptr[1]); switch (message->data.ptr[0]) { case 'I': // info. We can safely ignore that message break; case 'O': // open { metrics.method_calls[0]++; uint16_t channel = ntohs(*(uint16_t*)&message->data.ptr[3]); printf("%p[%04x] modem.open(%d)\n", (void*)client, request_id, channel); modem_open(client, request_id, channel); } return; case 'o': // isOpen { metrics.method_calls[1]++; uint16_t channel = ntohs(*(uint16_t*)&message->data.ptr[3]); printf("%p[%04x] modem.isOpen(%d)\n", (void*)client, request_id, channel); modem_isOpen(client, request_id, channel); } return; case 'c': // close { metrics.method_calls[2]++; uint16_t channel = ntohs(*(uint16_t*)&message->data.ptr[3]); printf("%p[%04x] modem.close(%d)\n", (void*)client, request_id, channel); modem_close(client, request_id, channel); } return; case 'C': // closeAll { metrics.method_calls[3]++; printf("%p[%04x] modem.closeAll()\n", (void*)client, request_id); modem_closeAll(client, request_id); } return; case 'T': // transmit { metrics.method_calls[4]++; uint16_t channel = ntohs(*(uint16_t*)&message->data.ptr[3]); uint16_t reply_channel = ntohs(*(uint16_t*)&message->data.ptr[5]); uint16_t data_length = ntohs(*(uint16_t*)&message->data.ptr[7]); modem_transmit(client, request_id, channel, reply_channel, (void*)&message->data.ptr[9], data_length); } return; default: ws_send_error(client, request_id, "Unknown opcode: 0x%02x", message->data.ptr[0]); connection->is_draining = 1; return; } } static void on_ws_disconnect(struct mg_connection *connection, void *data) { (void)data; struct client *client = *(struct client **)&connection->data[0]; if (client->connection == connection) { free(client); } } static void modem_open(struct client *client, uint16_t request_id, uint16_t channel) { if (client_is_open(client, channel)) { ws_respond(client, request_id, NULL, 0); } if (client->next_open_channel_index == MAX_OPEN_CHANNELS) { ws_send_error(client, request_id, "Too many open channels"); return; } client->open_channels[client->next_open_channel_index] = channel; client->next_open_channel_index++; ws_respond(client, request_id, NULL, 0); } static void modem_isOpen(struct client *client, uint16_t request_id, uint16_t channel) { unsigned char is_open = client_is_open(client, channel) ? 42 : 0; ws_respond(client, request_id, &is_open, 1); } static void modem_close(struct client *client, uint16_t request_id, uint16_t channel) { for (int i = 0; i < client->next_open_channel_index; i++) { if (client->open_channels[i] == channel) { client->open_channels[i] = client->open_channels[client->next_open_channel_index - 1]; client->next_open_channel_index--; break; } } ws_respond(client, request_id, NULL, 0); } static void modem_closeAll(struct client *client, uint16_t request_id) { client->next_open_channel_index = 0; memset(client->open_channels, 0, sizeof(uint16_t) * MAX_OPEN_CHANNELS); ws_respond(client, request_id, NULL, 0); } static void modem_transmit(struct client *client, uint16_t request_id, uint16_t channel, uint16_t reply_channel, void *data, uint16_t size) { static uint8_t buffer[MAX_PACKET_SIZE + 7]; if (size > MAX_PACKET_SIZE) { ws_send_error(client, request_id, "Packet too big: %d > %d", size, MAX_PACKET_SIZE); return; } buffer[0] = 'T'; buffer[1] = (channel >> 8) & 0xFF; buffer[2] = channel & 0xFF; buffer[3] = (reply_channel >> 8) & 0xFF; buffer[4] = reply_channel & 0xFF; buffer[5] = (size >> 8) & 0xFF; buffer[6] = size & 0xFF; memcpy(&buffer[7], data, size); for (struct mg_connection *conn = client->connection->mgr->conns; conn != NULL; conn = conn->next) { if (conn->is_websocket) { struct client *other_client = *(struct client **)&conn->data[0]; if (other_client->connection == conn && other_client->connection != client->connection) { if (client_is_open(other_client, channel)) { metrics.sent_bytes += size + 7; metrics.sent_messages++; mg_ws_send(other_client->connection, buffer, size + 7, WEBSOCKET_OP_BINARY); } } } } ws_respond(client, request_id, NULL, 0); } bool client_is_open(struct client *client, uint16_t channel) { for (int i = 0; i < client->next_open_channel_index; i++) { if (client->open_channels[i] == channel) { return true; } } return false; }