1
0
Fork 0
cc-stuff/augment/wsvpn.lua

155 lines
4.9 KiB
Lua

local expect = require("cc.expect")
-- Largest number of values handed to string.char / table.unpack in one call.
-- Large payloads are processed in slices of this size so a single call never
-- approaches the VM's limit on call arguments/results.
local BYTE_CHUNK = 4096

--- Encode an integer in the range 0-65535 as two big-endian bytes.
-- @tparam number n Value to encode.
-- @treturn string Two characters: the high byte, then the low byte.
local function encode_u16(n)
  return string.char(bit.band(0xFF, bit.brshift(n, 8)), bit.band(0xFF, n))
end

--- Decode a big-endian 16-bit integer from a string.
-- @tparam string s Source string.
-- @tparam number pos Index of the high byte.
-- @treturn number|nil The value, or nil if fewer than two bytes exist at `pos`.
local function decode_u16(s, pos)
  local hi, lo = s:byte(pos, pos + 1)
  if lo == nil then
    return nil
  end
  return bit.bor(bit.blshift(hi, 8), lo)
end

--- Copy bytes `first`..`last` of a string into a list of byte values.
-- @tparam string s Source string.
-- @tparam number first First index to copy.
-- @tparam number last Last index to copy (clamped to the string length).
-- @treturn {number...} The byte values, in order.
local function string_to_bytes(s, first, last)
  local bytes = {}
  for i = first, math.min(last, #s) do
    bytes[#bytes + 1] = s:byte(i)
  end
  return bytes
end

--- Concatenate a list of byte values into a string.
-- @tparam {number...} bytes Byte values (0-255).
-- @treturn string The corresponding string ("" for an empty list).
local function bytes_to_string(bytes)
  local parts = {}
  for i = 1, #bytes, BYTE_CHUNK do
    parts[#parts + 1] = string.char(table.unpack(bytes, i, math.min(i + BYTE_CHUNK - 1, #bytes)))
  end
  return table.concat(parts)
end

--- Decode one binary frame received from the server and act on it.
-- Every frame starts with a one-byte opcode; multi-byte integers are
-- big-endian u16s. A frame whose declared length exceeds its actual size is
-- rejected rather than partially processed.
-- @param modem The modem object the frame arrived on (its `side` may be set).
-- @tparam string frame The raw frame.
-- @treturn nil|string nil when handled, otherwise why the frame was rejected.
local function dispatch_frame(modem, frame)
  local opcode = frame:byte(1)
  if opcode == nil then
    return "Empty message"
  end

  if opcode == 0x49 then -- Info: u16 length, text
    local len = decode_u16(frame, 2)
    if len == nil or #frame < 3 + len then
      return "Truncated info message"
    end
    os.queueEvent("wsvpn:info", frame:sub(4, 3 + len))

  elseif opcode == 0x41 then -- Set address/side: u8 length, text
    local len = frame:byte(2)
    if len == nil or #frame < 2 + len then
      return "Truncated address message"
    end
    modem.side = frame:sub(3, 2 + len)

  elseif opcode == 0x45 then -- Error: u16 request id, u16 length, text
    local request_id, len = decode_u16(frame, 2), decode_u16(frame, 4)
    if len == nil or #frame < 5 + len then
      return "Truncated error message"
    end
    os.queueEvent("wsvpn:response", false, request_id, frame:sub(6, 5 + len))

  elseif opcode == 0x52 then -- Response: u16 request id, raw payload bytes
    local request_id = decode_u16(frame, 2)
    if request_id == nil then
      return "Truncated response message"
    end
    -- The payload is delivered as a list of byte values (isOpen reads [1]).
    os.queueEvent("wsvpn:response", true, request_id, string_to_bytes(frame, 4, #frame))

  elseif opcode == 0x54 then -- Transmission: u16 channel, u16 reply channel, u16 length, JSON
    local channel, reply_channel, len = decode_u16(frame, 2), decode_u16(frame, 4), decode_u16(frame, 6)
    if len == nil or #frame < 7 + len then
      return "Truncated transmission message"
    end
    local message = textutils.unserializeJSON(frame:sub(8, 7 + len))
    os.queueEvent("modem_message", modem.side or "wsmodem_0", channel, reply_channel, message, nil)

  else
    return string.format("Invalid opcode 0x%02x", opcode)
  end
  return nil
end

--- Modem-like object backed by a websocket connection to a wsvpn server.
-- Methods are named after the CC modem peripheral (open/isOpen/close/closeAll/
-- transmit/isWireless) and are expected to be bound to an instance, so every
-- internal call uses `self.method(...)` rather than `self:method(...)`.
-- `run` must be running concurrently (e.g. via `parallel`) for any request to
-- complete: `_request` waits for "wsvpn:response" events that `run` queues.
local WSModem = {
  --- Ask the server to open `channel` (0-65535). Raises on a server error.
  open = function(self, channel)
    expect.expect(1, channel, "number")
    expect.range(channel, 0, 65535)
    self._request(0x4f, encode_u16(channel))
  end,

  --- Ask the server whether `channel` (0-65535) is open.
  -- @treturn boolean true unless the first response byte is 0.
  isOpen = function(self, channel)
    expect.expect(1, channel, "number")
    expect.range(channel, 0, 65535)
    return self._request(0x6f, encode_u16(channel))[1] ~= 0
  end,

  --- Ask the server to close `channel` (0-65535). Raises on a server error.
  close = function(self, channel)
    expect.expect(1, channel, "number")
    expect.range(channel, 0, 65535)
    self._request(0x63, encode_u16(channel))
  end,

  --- Ask the server to close every channel. Raises on a server error.
  closeAll = function(self)
    self._request(0x43)
  end,

  --- Send `data` on `channel`, advertising `replyChannel` for answers.
  -- `data` is encoded as JSON and must serialize to at most 65535 bytes.
  -- @tparam number channel Destination channel (0-65535).
  -- @tparam number replyChannel Reply channel (0-65535).
  -- @param data nil, string, number, boolean or table.
  transmit = function(self, channel, replyChannel, data)
    expect.expect(1, channel, "number")
    expect.expect(2, replyChannel, "number")
    expect.expect(3, data, "nil", "string", "number", "boolean", "table")
    expect.range(channel, 0, 65535)
    expect.range(replyChannel, 0, 65535)
    -- textutils.serializeJSON rejects nil, so spell it as JSON null directly;
    -- unserializeJSON on the receiving side turns "null" back into nil.
    local serialized
    if data == nil then
      serialized = "null"
    else
      serialized = textutils.serializeJSON(data)
    end
    if #serialized > 65535 then
      error(("message too large (%d bytes, limit is 65535)"):format(#serialized), 2)
    end
    self._request(0x54, encode_u16(channel) .. encode_u16(replyChannel) .. encode_u16(#serialized) .. serialized)
  end,

  isWireless = function(self) return true end,

  --- Receive and dispatch server frames until the socket closes.
  -- @treturn boolean true when the socket closed, false on a protocol error.
  -- @treturn[opt] string Description of the protocol error.
  run = function(self)
    while true do
      local frame, binary = self._socket.receive()
      if not frame then return true end
      if binary == false then return false, "Not a binary message" end
      local err = dispatch_frame(self, frame)
      if err then return false, err end
      -- No os.sleep(0) here: receive() already yields, and sleeping would
      -- discard websocket_message events that arrive during the sleep.
    end
  end,

  -- low-level part

  --- Pop a big-endian u16 off the front of a byte list (mutates `data`).
  -- Kept for compatibility; frame parsing no longer uses it.
  -- @treturn number The decoded value.
  -- @treturn table The same `data` list, with two bytes removed.
  _read_u16ne = function(self, data)
    local v = bit.blshift(table.remove(data, 1), 8)
    v = bit.bor(v, table.remove(data, 1))
    return v, data
  end,

  --- Block until a "wsvpn:response" event for `request_id` arrives.
  -- @treturn boolean Whether the server reported success.
  -- @return The response byte list, or the error message string.
  _wait_response = function(self, request_id)
    while true do
      local _, ok, id, payload = os.pullEvent("wsvpn:response")
      if id == request_id then
        return ok, payload
      end
    end
  end,

  --- Send a request frame and wait for its response.
  -- @tparam number opcode One-byte request opcode.
  -- @param data Optional payload: a string, or a list of byte values.
  -- @return The response byte list on success.
  -- @raise The server's error message when it reports failure.
  _request = function(self, opcode, data)
    local request_id = self._get_id()
    local payload
    if type(data) == "table" then
      payload = bytes_to_string(data)
    else
      payload = data or ""
    end
    self._socket.send(string.char(opcode) .. encode_u16(request_id) .. payload, true)
    local ok, response = self._wait_response(request_id)
    if not ok then
      -- Level 0: the message comes from the server; a location inside this
      -- library would only mislead the caller.
      error(response, 0)
    end
    return response
  end,

  --- Return the next request id, wrapping within 16 bits.
  _get_id = function(self)
    self._req_id = bit.band(0xFFFF, self._req_id + 1)
    return self._req_id
  end,

  --- Send a text frame: `code`, u16 length, then the formatted text
  -- (truncated to 1020 bytes).
  _send_text = function(self, code, fmt, ...)
    local text = fmt:format(...):sub(1, 1020)
    self._socket.send(string.char(code) .. encode_u16(#text) .. text, true)
  end,

  --- Greet the server with this computer's id.
  _init = function(self)
    self._send_text(0x49, "Hello! I'm computer %d", os.getComputerID())
  end,
}
--- Connect to the wsvpn server at `addr` and return a modem-like object.
-- Every WSModem method is pre-bound to the new instance, so methods are
-- called with a dot (`modem.open(1)`), not a colon. A greeting is sent to
-- the server before returning.
-- @param addr Websocket address, passed straight to http.websocket.
-- @raise If the websocket cannot be opened.
return function(addr)
  local socket = assert(http.websocket(addr))
  local modem = setmetatable(
    { side = "wsmodem_unknown", _req_id = 0, _socket = socket },
    { __index = WSModem }
  )
  -- Bind each method so `self` is always this instance.
  for key, fn in pairs(WSModem) do
    modem[key] = function(...)
      return fn(modem, ...)
    end
  end
  modem._init()
  return modem
end